Add AutoBackbone + ResNetBackbone (#20229)

* Add ResNetBackbone

* Define channels and strides as property

* Remove file

* Add test for backbone

* Update BackboneOutput class

* Remove strides property

* Fix docstring

* Add backbones to SHOULD_HAVE_THEIR_OWN_PAGE

* Fix auto mapping name

* Add sanity check for out_features

* Set stage names based on depths

* Update to tuple

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-11-17 15:43:20 +01:00
committed by GitHub
parent 904ac21020
commit 6b217c52e6
10 changed files with 160 additions and 2 deletions

View File

@@ -30,7 +30,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import ResNetForImageClassification, ResNetModel
from transformers import ResNetBackbone, ResNetForImageClassification, ResNetModel
from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -55,6 +55,7 @@ class ResNetModelTester:
hidden_act="relu",
num_labels=3,
scope=None,
out_features=["stage1", "stage2", "stage3", "stage4"],
):
self.parent = parent
self.batch_size = batch_size
@@ -69,6 +70,7 @@ class ResNetModelTester:
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
self.out_features = out_features
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -89,6 +91,7 @@ class ResNetModelTester:
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
out_features=self.out_features,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -110,6 +113,19 @@ class ResNetModelTester:
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_backbone(self, config, pixel_values, labels):
model = ResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8])
# verify channels
self.parent.assertListEqual(model.channels, config.hidden_sizes)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
@@ -176,6 +192,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()