Config: warning when saving generation kwargs in the model config (#28514)
This commit is contained in:
@@ -296,3 +296,19 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
old_transformers.configuration_utils.__version__ = "v3.0.0"
|
||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(old_configuration.hidden_size, 768)
|
||||
|
||||
def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
|
||||
config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("min_length", logs.output[0])
|
||||
|
||||
def test_has_non_default_generation_parameters(self):
|
||||
config = BertConfig()
|
||||
self.assertFalse(config._has_non_default_generation_parameters())
|
||||
config = BertConfig(min_length=3)
|
||||
self.assertTrue(config._has_non_default_generation_parameters())
|
||||
config = BertConfig(min_length=0) # `min_length = 0` is a default generation kwarg
|
||||
self.assertFalse(config._has_non_default_generation_parameters())
|
||||
|
||||
Reference in New Issue
Block a user