[PretrainedConfig] Fix save pretrained config for edge case (#7943)

* fix config save

* add test

* add config class variable and another test

* line break

* fix fsmt and typo

* god am I making many errors today :-/

* Update src/transformers/configuration_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-10-22 15:39:01 +02:00
committed by GitHub
parent cc2e312ca3
commit f34372a9ff
6 changed files with 34 additions and 4 deletions

View File

@@ -901,6 +901,15 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_model_with_attn_mask(*config_and_inputs)
def test_config_save(self):
config = self.model_tester.prepare_config_and_inputs()[0]
config.add_cross_attention = False
with tempfile.TemporaryDirectory() as tmp_dirname:
config.save_pretrained(tmp_dirname)
config = ProphetNetConfig.from_pretrained(tmp_dirname)
self.assertFalse(config.add_cross_attention)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()