From 897ff9af0e8892167af1eb4ec58677001c3a0041 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Tue, 1 Apr 2025 14:14:39 +0200 Subject: [PATCH] [`ModernBERT`] Never save 'reference_compile' config; should be set based on end user (#36305) * Never save 'reference_compile' config; should be set based on end user * Reformat (I ran 'make style' from the wrong env) * Use pop instead of del Co-authored-by: Matt * Use pop instead of del Co-authored-by: Matt --------- Co-authored-by: Matt --- .../models/modernbert/configuration_modernbert.py | 5 +++++ .../models/modernbert/modular_modernbert.py | 5 +++++ tests/models/modernbert/test_modeling_modernbert.py | 10 ++++++++++ 3 files changed, 20 insertions(+) diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index cc0295c25b..1835f55aae 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -214,5 +214,10 @@ class ModernBertConfig(PretrainedConfig): f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' ) + def to_dict(self): + output = super().to_dict() + output.pop("reference_compile", None) + return output + __all__ = ["ModernBertConfig"] diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 0901662f66..934931da3b 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -248,6 +248,11 @@ class ModernBertConfig(PretrainedConfig): f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' ) + def to_dict(self): + output = super().to_dict() + output.pop("reference_compile", None) + return output + def _unpad_modernbert_input( inputs: torch.Tensor, diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 14882b0879..82a0f85052 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import tempfile import unittest import pytest @@ -366,6 +368,14 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa def test_flash_attn_2_conversion(self): self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.") + def test_saved_config_excludes_reference_compile(self): + config = ModernBertConfig(reference_compile=True) + with tempfile.TemporaryDirectory() as tmpdirname: + config.save_pretrained(tmpdirname) + with open(os.path.join(tmpdirname, "config.json"), "r") as f: + config_dict = json.load(f) + self.assertNotIn("reference_compile", config_dict) + @require_torch class ModernBertModelIntegrationTest(unittest.TestCase):