Use self.config_tester.run_common_tests() (#31431)

* First testing updating config tests

* Use run_common_tests
This commit is contained in:
amyeroberts
2024-06-19 10:18:08 +01:00
committed by GitHub
parent 7c71b61dae
commit 609e662243
28 changed files with 174 additions and 187 deletions

View File

@@ -18,7 +18,6 @@ import inspect
import unittest
from transformers import AutoBackbone
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import require_timm, require_torch, torch_device
from transformers.utils.import_utils import is_torch_available
@@ -106,17 +105,15 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
has_attentions = False
def setUp(self):
self.config_class = PretrainedConfig
# self.config_class = PretrainedConfig
self.config_class = TimmBackboneConfig
self.model_tester = TimmBackboneModelTester(self)
self.config_tester = ConfigTester(self, config_class=self.config_class, has_text_modality=False)
self.config_tester = ConfigTester(
self, config_class=self.config_class, has_text_modality=False, common_properties=["num_channels"]
)
def test_config(self):
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
self.config_tester.run_common_tests()
def test_timm_transformer_backbone_equivalence(self):
timm_checkpoint = "resnet18"