Add config validation and style tweaks (#37589)

* Add config validation and style tweaks

* Fix style issues

* Fix style issues

* style

* Small fixes for copy/paste errors

---------

Co-authored-by: Cyrile <cyrile.delestre@arkea.com>
This commit is contained in:
Kirire
2025-05-14 14:22:10 +02:00
committed by GitHub
parent 1b00966395
commit 935bbbc711
3 changed files with 44 additions and 21 deletions

View File

@@ -44,6 +44,29 @@ if is_torch_available():
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
class Mamba2ConfigTester(ConfigTester):
def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
_input_dict = self.inputs_dict.copy()
_input_dict["hidden_size"] = hidden_size
_input_dict["num_heads"] = num_heads
_input_dict["expand"] = expand
_input_dict["head_dim"] = head_dim
return self.config_class(**_input_dict)
def test_hidden_size_compatibility(self):
self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
with self.parent.assertRaises(ValueError):
self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
with self.parent.assertRaises(ValueError):
self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
def run_common_tests(self):
self.test_hidden_size_compatibility()
return super().run_common_tests()
class Mamba2ModelTester:
def __init__(
self,
@@ -233,7 +256,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def setUp(self):
self.model_tester = Mamba2ModelTester(self)
self.config_tester = ConfigTester(
self.config_tester = Mamba2ConfigTester(
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)