[MptConfig] support from pretrained args (#25116)
* support from pretrained args * draft addition of tests * update test * use parrent assert true * Update src/transformers/models/mpt/configuration_mpt.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -327,6 +327,20 @@ class MptModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class MptConfigTester(ConfigTester):
|
||||
def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
|
||||
super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
|
||||
|
||||
def test_attn_config_as_dict(self):
|
||||
config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
|
||||
self.parent.assertTrue(config.attn_config.attn_impl == "flash")
|
||||
self.parent.assertTrue(config.attn_config.softmax_scale is None)
|
||||
|
||||
def run_common_tests(self):
|
||||
self.test_attn_config_as_dict()
|
||||
return super().run_common_tests()
|
||||
|
||||
|
||||
@require_torch
|
||||
class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
@@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MptModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
|
||||
self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
Reference in New Issue
Block a user