[ModernBERT] Never save 'reference_compile' config; should be set based on end user (#36305)
* Never save 'reference_compile' config; should be set based on end user * Reformat (I ran 'make style' from the wrong env) * Use pop instead of del Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Use pop instead of del Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> --------- Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -214,5 +214,10 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
|
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
output = super().to_dict()
|
||||||
|
output.pop("reference_compile", None)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ModernBertConfig"]
|
__all__ = ["ModernBertConfig"]
|
||||||
|
|||||||
@@ -248,6 +248,11 @@ class ModernBertConfig(PretrainedConfig):
|
|||||||
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
|
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
output = super().to_dict()
|
||||||
|
output.pop("reference_compile", None)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _unpad_modernbert_input(
|
def _unpad_modernbert_input(
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
|
|||||||
@@ -12,7 +12,9 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -366,6 +368,14 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
def test_flash_attn_2_conversion(self):
|
def test_flash_attn_2_conversion(self):
|
||||||
self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
|
self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
|
||||||
|
|
||||||
|
def test_saved_config_excludes_reference_compile(self):
|
||||||
|
config = ModernBertConfig(reference_compile=True)
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
config.save_pretrained(tmpdirname)
|
||||||
|
with open(os.path.join(tmpdirname, "config.json"), "r") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
self.assertNotIn("reference_compile", config_dict)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ModernBertModelIntegrationTest(unittest.TestCase):
|
class ModernBertModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user