fix to jamba config, asserting attention and expert offset (#33316)
* fix to jamba config, asserting attention and expert offset * fix foramtting * fix foramtting * fix foramtting * changed to error raise instead of assertion, added unittests * fix * changed t_ to property_ * changed t_ to property_ * quickfix * ran code styler
This commit is contained in:
@@ -50,6 +50,48 @@ if is_torch_available():
|
||||
)
|
||||
|
||||
|
||||
class JambaConfigTester(ConfigTester):
|
||||
def _create_attn_config(self, attn_layer_offset: int, attn_layer_period: int):
|
||||
_input_dict = self.inputs_dict.copy()
|
||||
_input_dict["attn_layer_offset"] = attn_layer_offset
|
||||
_input_dict["attn_layer_period"] = attn_layer_period
|
||||
return self.config_class(**_input_dict)
|
||||
|
||||
def _create_expert_config(self, expert_layer_offset: int, expert_layer_period: int):
|
||||
_input_dict = self.inputs_dict.copy()
|
||||
_input_dict["expert_layer_offset"] = expert_layer_offset
|
||||
_input_dict["expert_layer_period"] = expert_layer_period
|
||||
return self.config_class(**_input_dict)
|
||||
|
||||
def test_attn_offsets(self):
|
||||
self._create_attn_config(attn_layer_offset=0, attn_layer_period=4)
|
||||
self._create_attn_config(attn_layer_offset=1, attn_layer_period=4)
|
||||
self._create_attn_config(attn_layer_offset=2, attn_layer_period=4)
|
||||
self._create_attn_config(attn_layer_offset=3, attn_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_attn_config(attn_layer_offset=4, attn_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_attn_config(attn_layer_offset=5, attn_layer_period=4)
|
||||
|
||||
def test_expert_offsets(self):
|
||||
self._create_expert_config(expert_layer_offset=0, expert_layer_period=4)
|
||||
self._create_expert_config(expert_layer_offset=1, expert_layer_period=4)
|
||||
self._create_expert_config(expert_layer_offset=2, expert_layer_period=4)
|
||||
self._create_expert_config(expert_layer_offset=3, expert_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_expert_config(expert_layer_offset=4, expert_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_expert_config(expert_layer_offset=5, expert_layer_period=4)
|
||||
|
||||
def test_jamba_offset_properties(self):
|
||||
self.test_attn_offsets()
|
||||
self.test_expert_offsets()
|
||||
|
||||
def run_common_tests(self):
|
||||
self.test_jamba_offset_properties()
|
||||
return super().run_common_tests()
|
||||
|
||||
|
||||
class JambaModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -302,7 +344,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = JambaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37)
|
||||
self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
Reference in New Issue
Block a user