[ModernBERT] Never save 'reference_compile' config; should be set based on end user (#36305)

* Never save 'reference_compile' config; should be set based on end user

* Reformat (I ran 'make style' from the wrong env)

* Use pop instead of del

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Use pop instead of del

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Tom Aarsen
2025-04-01 14:14:39 +02:00
committed by GitHub
parent c0bd8048a5
commit 897ff9af0e
3 changed files with 20 additions and 0 deletions

View File

@@ -214,5 +214,10 @@ class ModernBertConfig(PretrainedConfig):
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
) )
def to_dict(self):
output = super().to_dict()
output.pop("reference_compile", None)
return output
__all__ = ["ModernBertConfig"] __all__ = ["ModernBertConfig"]

View File

@@ -248,6 +248,11 @@ class ModernBertConfig(PretrainedConfig):
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
) )
def to_dict(self):
output = super().to_dict()
output.pop("reference_compile", None)
return output
def _unpad_modernbert_input( def _unpad_modernbert_input(
inputs: torch.Tensor, inputs: torch.Tensor,

View File

@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import os import os
import tempfile
import unittest import unittest
import pytest import pytest
@@ -366,6 +368,14 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
def test_flash_attn_2_conversion(self): def test_flash_attn_2_conversion(self):
self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.") self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
def test_saved_config_excludes_reference_compile(self):
config = ModernBertConfig(reference_compile=True)
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_pretrained(tmpdirname)
with open(os.path.join(tmpdirname, "config.json"), "r") as f:
config_dict = json.load(f)
self.assertNotIn("reference_compile", config_dict)
@require_torch @require_torch
class ModernBertModelIntegrationTest(unittest.TestCase): class ModernBertModelIntegrationTest(unittest.TestCase):