Backbone add mixin tests (#22542)
* Add out_indices to backbones, deprecate out_features * Update - can specify both out_features and out_indices but not both * Add backbone mixin tests * Test tidy up * Add test_backbone for convnext * Remove redefinition of method * Update for Dinat and Nat backbones * Update tests * Smarter indexing * Add checks on config creation for backbone * PR comments
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import SwinConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_backbone_common import BackboneTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
@@ -69,6 +70,7 @@ class SwinModelTester:
|
||||
type_sequence_label_size=10,
|
||||
encoder_stride=8,
|
||||
out_features=["stage1", "stage2"],
|
||||
out_indices=[1, 2],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -95,6 +97,7 @@ class SwinModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.encoder_stride = encoder_stride
|
||||
self.out_features = out_features
|
||||
self.out_indices = out_indices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -128,6 +131,7 @@ class SwinModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
out_features=self.out_features,
|
||||
out_indices=self.out_indices,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -502,3 +506,12 @@ class SwinModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
@require_torch
|
||||
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
all_model_classes = (SwinBackbone,) if is_torch_available() else ()
|
||||
config_class = SwinConfig
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SwinModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user