Refactor DBRX tests to use CausalLMModelTest base classes (#38475)

* Refactor DBRX tests to use CausalLMModelTest base classes

- Changed DbrxModelTester to inherit from CausalLMModelTester
- Changed DbrxModelTest to inherit from CausalLMModelTest
- Removed duplicate methods that are already in base classes
- Added required class attributes for model classes
- Updated pipeline_model_mapping to include feature-extraction
- Kept DBRX-specific configuration and test methods
- Disabled RoPE tests as DBRX's rotary embedding doesn't accept config parameter

This refactoring reduces code duplication and follows the pattern established
in other causal LM model tests like Gemma.

* Apply style fixes

* Trigger tests

* Refactor DBRX test

* Make sure the DBRX-specific settings are handled

* Use the attribute_map

* Fix attribute map

---------

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Matt
2025-06-13 16:22:12 +01:00
committed by GitHub
parent 64041694a8
commit b82a45b3b4
2 changed files with 54 additions and 173 deletions

View File

@@ -181,11 +181,18 @@ class CausalLMModelTester:
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@property
def config_args(self):
return list(signature(self.config_class.__init__).parameters.keys())
def get_config(self):
kwarg_names = list(signature(self.config_class.__init__).parameters.keys())
kwargs = {
k: getattr(self, k) for k in kwarg_names + self.forced_config_args if hasattr(self, k) and k != "self"
}
kwargs = {}
model_name_to_common_name = {v: k for k, v in self.config_class.attribute_map.items()}
for k in self.config_args + self.forced_config_args:
if hasattr(self, k) and k != "self":
kwargs[k] = getattr(self, k)
elif k in model_name_to_common_name and hasattr(self, model_name_to_common_name[k]):
kwargs[k] = getattr(self, model_name_to_common_name[k])
return self.config_class(**kwargs)
def create_and_check_model(