fix to jamba config, asserting attention and expert offset (#33316)

* fix to jamba config, asserting attention and expert offset

* fix foramtting

* fix foramtting

* fix foramtting

* changed to error raise instead of assertion, added unittests

* fix

* changed t_ to property_

* changed t_ to property_

* quickfix

* ran code styler
This commit is contained in:
ErezSC42
2024-09-17 21:29:27 +03:00
committed by GitHub
parent 3476c19e91
commit 46c27577b3
2 changed files with 52 additions and 1 deletions

View File

@@ -50,6 +50,48 @@ if is_torch_available():
)
class JambaConfigTester(ConfigTester):
def _create_attn_config(self, attn_layer_offset: int, attn_layer_period: int):
_input_dict = self.inputs_dict.copy()
_input_dict["attn_layer_offset"] = attn_layer_offset
_input_dict["attn_layer_period"] = attn_layer_period
return self.config_class(**_input_dict)
def _create_expert_config(self, expert_layer_offset: int, expert_layer_period: int):
_input_dict = self.inputs_dict.copy()
_input_dict["expert_layer_offset"] = expert_layer_offset
_input_dict["expert_layer_period"] = expert_layer_period
return self.config_class(**_input_dict)
def test_attn_offsets(self):
self._create_attn_config(attn_layer_offset=0, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=1, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=2, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=3, attn_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_attn_config(attn_layer_offset=4, attn_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_attn_config(attn_layer_offset=5, attn_layer_period=4)
def test_expert_offsets(self):
self._create_expert_config(expert_layer_offset=0, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=1, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=2, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=3, expert_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_expert_config(expert_layer_offset=4, expert_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_expert_config(expert_layer_offset=5, expert_layer_period=4)
def test_jamba_offset_properties(self):
self.test_attn_offsets()
self.test_expert_offsets()
def run_common_tests(self):
self.test_jamba_offset_properties()
return super().run_common_tests()
class JambaModelTester:
def __init__(
self,
@@ -302,7 +344,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def setUp(self):
self.model_tester = JambaModelTester(self)
self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37)
self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()