fix to jamba config, asserting attention and expert offset (#33316)
* fix to jamba config, asserting attention and expert offset * fix foramtting * fix foramtting * fix foramtting * changed to error raise instead of assertion, added unittests * fix * changed t_ to property_ * changed t_ to property_ * quickfix * ran code styler
This commit is contained in:
@@ -193,6 +193,9 @@ class JambaConfig(PretrainedConfig):
|
||||
self.attn_layer_period = attn_layer_period
|
||||
self.attn_layer_offset = attn_layer_offset
|
||||
|
||||
self._check_supported_offset("attention", self.attn_layer_period, self.attn_layer_offset)
|
||||
self._check_supported_offset("expert", self.expert_layer_period, self.expert_layer_offset)
|
||||
|
||||
self.use_mamba_kernels = use_mamba_kernels
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
@@ -222,3 +225,9 @@ class JambaConfig(PretrainedConfig):
|
||||
self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
def _check_supported_offset(self, property_: str, period: int, offset: int):
|
||||
if offset >= period:
|
||||
raise ValueError(
|
||||
f"{property_} layer offset ({offset}) must be smaller than {property_} layer period ({period})"
|
||||
)
|
||||
|
||||
@@ -50,6 +50,48 @@ if is_torch_available():
|
||||
)
|
||||
|
||||
|
||||
class JambaConfigTester(ConfigTester):
|
||||
def _create_attn_config(self, attn_layer_offset: int, attn_layer_period: int):
|
||||
_input_dict = self.inputs_dict.copy()
|
||||
_input_dict["attn_layer_offset"] = attn_layer_offset
|
||||
_input_dict["attn_layer_period"] = attn_layer_period
|
||||
return self.config_class(**_input_dict)
|
||||
|
||||
def _create_expert_config(self, expert_layer_offset: int, expert_layer_period: int):
|
||||
_input_dict = self.inputs_dict.copy()
|
||||
_input_dict["expert_layer_offset"] = expert_layer_offset
|
||||
_input_dict["expert_layer_period"] = expert_layer_period
|
||||
return self.config_class(**_input_dict)
|
||||
|
||||
def test_attn_offsets(self):
|
||||
self._create_attn_config(attn_layer_offset=0, attn_layer_period=4)
|
||||
self._create_attn_config(attn_layer_offset=1, attn_layer_period=4)
|
||||
self._create_attn_config(attn_layer_offset=2, attn_layer_period=4)
|
||||
self._create_attn_config(attn_layer_offset=3, attn_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_attn_config(attn_layer_offset=4, attn_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_attn_config(attn_layer_offset=5, attn_layer_period=4)
|
||||
|
||||
def test_expert_offsets(self):
|
||||
self._create_expert_config(expert_layer_offset=0, expert_layer_period=4)
|
||||
self._create_expert_config(expert_layer_offset=1, expert_layer_period=4)
|
||||
self._create_expert_config(expert_layer_offset=2, expert_layer_period=4)
|
||||
self._create_expert_config(expert_layer_offset=3, expert_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_expert_config(expert_layer_offset=4, expert_layer_period=4)
|
||||
with self.parent.assertRaises(ValueError):
|
||||
self._create_expert_config(expert_layer_offset=5, expert_layer_period=4)
|
||||
|
||||
def test_jamba_offset_properties(self):
|
||||
self.test_attn_offsets()
|
||||
self.test_expert_offsets()
|
||||
|
||||
def run_common_tests(self):
|
||||
self.test_jamba_offset_properties()
|
||||
return super().run_common_tests()
|
||||
|
||||
|
||||
class JambaModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -302,7 +344,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = JambaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37)
|
||||
self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
Reference in New Issue
Block a user