[MptConfig] support from pretrained args (#25116)
* support from pretrained args * draft addition of tests * update test * use parent assert true * Update src/transformers/models/mpt/configuration_mpt.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -101,6 +101,23 @@ class MptAttentionConfig(PretrainedConfig):
|
|||||||
f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
|
f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
    """
    Instantiate this config from a pretrained model name or path.

    The raw config dictionary is obtained through `cls.get_config_dict`. If that dictionary declares
    `model_type == "mpt"`, only its nested `"attn_config"` entry is used. A warning is logged when the
    dictionary that is finally used declares a `model_type` different from `cls.model_type`.

    Returns:
        The object produced by `cls.from_dict` for the selected dictionary and the remaining kwargs.
    """
    cls._set_token_in_kwargs(kwargs)
    loaded, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

    # A full MPT model config stores the attention settings under "attn_config".
    if loaded.get("model_type") == "mpt":
        loaded = loaded["attn_config"]

    if "model_type" in loaded and hasattr(cls, "model_type"):
        found_type = loaded["model_type"]
        if found_type != cls.model_type:
            logger.warning(
                f"You are using a model of type {found_type} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

    return cls.from_dict(loaded, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class MptConfig(PretrainedConfig):
|
class MptConfig(PretrainedConfig):
|
||||||
"""
|
"""
|
||||||
@@ -180,6 +197,7 @@ class MptConfig(PretrainedConfig):
|
|||||||
"hidden_size": "d_model",
|
"hidden_size": "d_model",
|
||||||
"num_hidden_layers": "n_layers",
|
"num_hidden_layers": "n_layers",
|
||||||
}
|
}
|
||||||
|
is_composition = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -204,6 +222,7 @@ class MptConfig(PretrainedConfig):
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
self.attn_config = attn_config
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
self.n_layers = n_layers
|
self.n_layers = n_layers
|
||||||
@@ -222,20 +241,25 @@ class MptConfig(PretrainedConfig):
|
|||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attn_config(self):
|
||||||
|
return self._attn_config
|
||||||
|
|
||||||
|
@attn_config.setter
|
||||||
|
def attn_config(self, attn_config):
|
||||||
if attn_config is None:
|
if attn_config is None:
|
||||||
self.attn_config = MptAttentionConfig()
|
self._attn_config = MptAttentionConfig()
|
||||||
elif isinstance(attn_config, dict):
|
elif isinstance(attn_config, dict):
|
||||||
self.attn_config = MptAttentionConfig(**attn_config)
|
self._attn_config = MptAttentionConfig(**attn_config)
|
||||||
elif isinstance(attn_config, MptAttentionConfig):
|
elif isinstance(attn_config, MptAttentionConfig):
|
||||||
self.attn_config = attn_config
|
self._attn_config = attn_config
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
|
f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||||
@@ -245,7 +269,8 @@ class MptConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
output = copy.deepcopy(self.__dict__)
|
output = copy.deepcopy(self.__dict__)
|
||||||
output["attn_config"] = (
|
output["attn_config"] = (
|
||||||
self.attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
|
self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
|
||||||
)
|
)
|
||||||
|
del output["_attn_config"]
|
||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -327,6 +327,20 @@ class MptModelTester:
|
|||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
class MptConfigTester(ConfigTester):
    """Config tester that adds a check for `attn_config` passed as a plain dictionary."""

    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)

    def test_attn_config_as_dict(self):
        """Build a config with a dict `attn_config` and check its values are readable as attributes."""
        config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
        # assertEqual / assertIsNone report the offending value on failure,
        # which assertTrue on a pre-evaluated comparison cannot do.
        self.parent.assertEqual(config.attn_config.attn_impl, "flash")
        self.parent.assertIsNone(config.attn_config.softmax_scale)

    def run_common_tests(self):
        """Run the dict `attn_config` check first, then the inherited common tests."""
        self.test_attn_config_as_dict()
        return super().run_common_tests()
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
@@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = MptModelTester(self)
|
self.model_tester = MptModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
|
self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|||||||
Reference in New Issue
Block a user