fix to jamba config, asserting attention and expert offset (#33316)
* fix to jamba config, asserting attention and expert offset * fix foramtting * fix foramtting * fix foramtting * changed to error raise instead of assertion, added unittests * fix * changed t_ to property_ * changed t_ to property_ * quickfix * ran code styler
This commit is contained in:
@@ -193,6 +193,9 @@ class JambaConfig(PretrainedConfig):
|
|||||||
self.attn_layer_period = attn_layer_period
|
self.attn_layer_period = attn_layer_period
|
||||||
self.attn_layer_offset = attn_layer_offset
|
self.attn_layer_offset = attn_layer_offset
|
||||||
|
|
||||||
|
self._check_supported_offset("attention", self.attn_layer_period, self.attn_layer_offset)
|
||||||
|
self._check_supported_offset("expert", self.expert_layer_period, self.expert_layer_offset)
|
||||||
|
|
||||||
self.use_mamba_kernels = use_mamba_kernels
|
self.use_mamba_kernels = use_mamba_kernels
|
||||||
self.mamba_d_state = mamba_d_state
|
self.mamba_d_state = mamba_d_state
|
||||||
self.mamba_d_conv = mamba_d_conv
|
self.mamba_d_conv = mamba_d_conv
|
||||||
@@ -222,3 +225,9 @@ class JambaConfig(PretrainedConfig):
|
|||||||
self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
|
self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
|
||||||
for i in range(self.num_hidden_layers)
|
for i in range(self.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _check_supported_offset(self, property_: str, period: int, offset: int):
|
||||||
|
if offset >= period:
|
||||||
|
raise ValueError(
|
||||||
|
f"{property_} layer offset ({offset}) must be smaller than {property_} layer period ({period})"
|
||||||
|
)
|
||||||
|
|||||||
@@ -50,6 +50,48 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class JambaConfigTester(ConfigTester):
|
||||||
|
def _create_attn_config(self, attn_layer_offset: int, attn_layer_period: int):
|
||||||
|
_input_dict = self.inputs_dict.copy()
|
||||||
|
_input_dict["attn_layer_offset"] = attn_layer_offset
|
||||||
|
_input_dict["attn_layer_period"] = attn_layer_period
|
||||||
|
return self.config_class(**_input_dict)
|
||||||
|
|
||||||
|
def _create_expert_config(self, expert_layer_offset: int, expert_layer_period: int):
|
||||||
|
_input_dict = self.inputs_dict.copy()
|
||||||
|
_input_dict["expert_layer_offset"] = expert_layer_offset
|
||||||
|
_input_dict["expert_layer_period"] = expert_layer_period
|
||||||
|
return self.config_class(**_input_dict)
|
||||||
|
|
||||||
|
def test_attn_offsets(self):
|
||||||
|
self._create_attn_config(attn_layer_offset=0, attn_layer_period=4)
|
||||||
|
self._create_attn_config(attn_layer_offset=1, attn_layer_period=4)
|
||||||
|
self._create_attn_config(attn_layer_offset=2, attn_layer_period=4)
|
||||||
|
self._create_attn_config(attn_layer_offset=3, attn_layer_period=4)
|
||||||
|
with self.parent.assertRaises(ValueError):
|
||||||
|
self._create_attn_config(attn_layer_offset=4, attn_layer_period=4)
|
||||||
|
with self.parent.assertRaises(ValueError):
|
||||||
|
self._create_attn_config(attn_layer_offset=5, attn_layer_period=4)
|
||||||
|
|
||||||
|
def test_expert_offsets(self):
|
||||||
|
self._create_expert_config(expert_layer_offset=0, expert_layer_period=4)
|
||||||
|
self._create_expert_config(expert_layer_offset=1, expert_layer_period=4)
|
||||||
|
self._create_expert_config(expert_layer_offset=2, expert_layer_period=4)
|
||||||
|
self._create_expert_config(expert_layer_offset=3, expert_layer_period=4)
|
||||||
|
with self.parent.assertRaises(ValueError):
|
||||||
|
self._create_expert_config(expert_layer_offset=4, expert_layer_period=4)
|
||||||
|
with self.parent.assertRaises(ValueError):
|
||||||
|
self._create_expert_config(expert_layer_offset=5, expert_layer_period=4)
|
||||||
|
|
||||||
|
def test_jamba_offset_properties(self):
|
||||||
|
self.test_attn_offsets()
|
||||||
|
self.test_expert_offsets()
|
||||||
|
|
||||||
|
def run_common_tests(self):
|
||||||
|
self.test_jamba_offset_properties()
|
||||||
|
return super().run_common_tests()
|
||||||
|
|
||||||
|
|
||||||
class JambaModelTester:
|
class JambaModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -302,7 +344,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = JambaModelTester(self)
|
self.model_tester = JambaModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37)
|
self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37)
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|||||||
Reference in New Issue
Block a user