[ModernBERT] Never save 'reference_compile' config; should be set based on end user (#36305)
* Never save 'reference_compile' config; should be set based on end user * Reformat (I ran 'make style' from the wrong env) * Use pop instead of del Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Use pop instead of del Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> --------- Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -214,5 +214,10 @@ class ModernBertConfig(PretrainedConfig):
|
||||
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
output.pop("reference_compile", None)
|
||||
return output
|
||||
|
||||
|
||||
__all__ = ["ModernBertConfig"]
|
||||
|
||||
@@ -248,6 +248,11 @@ class ModernBertConfig(PretrainedConfig):
|
||||
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
output.pop("reference_compile", None)
|
||||
return output
|
||||
|
||||
|
||||
def _unpad_modernbert_input(
|
||||
inputs: torch.Tensor,
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -366,6 +368,14 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
def test_flash_attn_2_conversion(self):
|
||||
self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
|
||||
|
||||
def test_saved_config_excludes_reference_compile(self):
|
||||
config = ModernBertConfig(reference_compile=True)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
config.save_pretrained(tmpdirname)
|
||||
with open(os.path.join(tmpdirname, "config.json"), "r") as f:
|
||||
config_dict = json.load(f)
|
||||
self.assertNotIn("reference_compile", config_dict)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModernBertModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user