[MptConfig] support from pretrained args (#25116)

* support from pretrained args

* draft addition of tests

* update test

* use parent assert true

* Update src/transformers/models/mpt/configuration_mpt.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Arthur
2023-07-27 16:24:52 +02:00
committed by GitHub
parent a1c4954d25
commit 9cea3e7b80
2 changed files with 46 additions and 7 deletions

View File

@@ -327,6 +327,20 @@ class MptModelTester:
return config, inputs_dict
class MptConfigTester(ConfigTester):
    """Config tester that adds a check for a dict-valued ``attn_config``.

    The extra check runs before the checks inherited from ``ConfigTester``.
    ``parent`` is used for assertions, so it must expose ``unittest``-style
    assertion methods.
    """

    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
        """Forward all arguments unchanged to the ``ConfigTester`` constructor."""
        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)

    def test_attn_config_as_dict(self):
        """Check that ``attn_config`` passed as a plain dict is readable via attributes."""
        # presumably ``inputs_dict`` is populated by the base class from **kwargs -- TODO confirm
        config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
        # assertEqual / assertIsNone report the offending value on failure,
        # which assertTrue(<comparison>) cannot do.
        self.parent.assertEqual(config.attn_config.attn_impl, "flash")
        self.parent.assertIsNone(config.attn_config.softmax_scale)

    def run_common_tests(self):
        """Run the ``attn_config`` check, then return the result of the inherited common checks."""
        self.test_attn_config_as_dict()
        return super().run_common_tests()
@require_torch
class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
@@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def setUp(self):
# Create the model tester and the config tester used by this test class.
self.model_tester = MptModelTester(self)
# NOTE(review): this text looks like a diff view whose +/- markers were lost; the next
# line appears to be the removed (pre-change) side of the hunk -- confirm against the real file.
self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
# Read as plain code, this assignment overwrites the one above, so MptConfigTester is the
# tester that ends up stored in self.config_tester.
self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)
def test_config(self):
    """Run the common configuration checks through ``self.config_tester``."""
    tester = self.config_tester
    tester.run_common_tests()