Add config validation and style tweaks (#37589)
* Add config validation and style tweaks * Fix style issues * Fix style issues * style * Small fixes for copy/paste errors --------- Co-authored-by: Cyrile <cyrile.delestre@arkea.com>
This commit is contained in:
@@ -44,6 +44,29 @@ if is_torch_available():
|
||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
|
||||
|
||||
|
||||
class Mamba2ConfigTester(ConfigTester):
|
||||
def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
|
||||
_input_dict = self.inputs_dict.copy()
|
||||
_input_dict["hidden_size"] = hidden_size
|
||||
_input_dict["num_heads"] = num_heads
|
||||
_input_dict["expand"] = expand
|
||||
_input_dict["head_dim"] = head_dim
|
||||
return self.config_class(**_input_dict)
|
||||
|
||||
def test_hidden_size_compatibility(self):
|
||||
self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
|
||||
self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
|
||||
self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
|
||||
|
||||
def run_common_tests(self):
|
||||
self.test_hidden_size_compatibility()
|
||||
return super().run_common_tests()
|
||||
|
||||
|
||||
class Mamba2ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -233,7 +256,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Mamba2ModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self.config_tester = Mamba2ConfigTester(
|
||||
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user