Add AutoBackbone + ResNetBackbone (#20229)
* Add ResNetBackbone * Define channels and strides as property * Remove file * Add test for backbone * Update BackboneOutput class * Remove strides property * Fix docstring * Add backbones to SHOULD_HAVE_THEIR_OWN_PAGE * Fix auto mapping name * Add sanity check for out_features * Set stage names based on depths * Update to tuple Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -30,7 +30,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import ResNetForImageClassification, ResNetModel
|
||||
from transformers import ResNetBackbone, ResNetForImageClassification, ResNetModel
|
||||
from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ class ResNetModelTester:
|
||||
hidden_act="relu",
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
out_features=["stage1", "stage2", "stage3", "stage4"],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -69,6 +70,7 @@ class ResNetModelTester:
|
||||
self.num_labels = num_labels
|
||||
self.scope = scope
|
||||
self.num_stages = len(hidden_sizes)
|
||||
self.out_features = out_features
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -89,6 +91,7 @@ class ResNetModelTester:
|
||||
depths=self.depths,
|
||||
hidden_act=self.hidden_act,
|
||||
num_labels=self.num_labels,
|
||||
out_features=self.out_features,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -110,6 +113,19 @@ class ResNetModelTester:
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_backbone(self, config, pixel_values, labels):
|
||||
model = ResNetBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify hidden states
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertListEqual(model.channels, config.hidden_sizes)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
@@ -176,6 +192,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user