[NAT, DiNAT] Add backbone class (#20654)

* Add first draft

* Add out_features attribute to config

* Add corresponding test

* Add Dinat backbone

* Add BackboneMixin

* Add Backbone mixin, improve tests

* Fix embeddings

* Fix bug

* Improve backbones

* Fix Nat backbone tests

* Fix Dinat backbone tests

* Apply suggestions

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-12-13 17:06:59 +01:00
committed by GitHub
parent 30d8919ab1
commit 6ef42587ae
12 changed files with 449 additions and 60 deletions

View File

@@ -30,7 +30,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import DinatForImageClassification, DinatModel
from transformers import DinatBackbone, DinatForImageClassification, DinatModel
from transformers.models.dinat.modeling_dinat import DINAT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -64,8 +64,8 @@ class DinatModelTester:
is_training=True,
scope=None,
use_labels=True,
type_sequence_label_size=10,
encoder_stride=8,
num_labels=10,
out_features=["stage1", "stage2"],
):
self.parent = parent
self.batch_size = batch_size
@@ -89,15 +89,15 @@ class DinatModelTester:
self.is_training = is_training
self.scope = scope
self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
self.num_labels = num_labels
self.out_features = out_features
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
@@ -105,6 +105,7 @@ class DinatModelTester:
def get_config(self):
return DinatConfig(
num_labels=self.num_labels,
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
@@ -122,7 +123,7 @@ class DinatModelTester:
patch_norm=self.patch_norm,
layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
out_features=self.out_features,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -139,12 +140,11 @@ class DinatModelTester:
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = DinatForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
# test greyscale images
config.num_channels = 1
@@ -154,7 +154,34 @@ class DinatModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_backbone(self, config, pixel_values, labels):
model = DinatBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = DinatBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -167,7 +194,15 @@ class DinatModelTester:
@require_torch
class DinatModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (DinatModel, DinatForImageClassification) if is_torch_available() else ()
all_model_classes = (
(
DinatModel,
DinatForImageClassification,
DinatBackbone,
)
if is_torch_available()
else ()
)
fx_compatible = False
test_torchscript = False
@@ -199,8 +234,16 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
@unittest.skip(reason="Dinat does not use inputs_embeds")
def test_inputs_embeds(self):
# Dinat does not use inputs_embeds
pass
@unittest.skip(reason="Dinat does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_model_common_attributes(self):
@@ -257,17 +300,18 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase):
[height, width, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
if model_class.__name__ != "DinatBackbone":
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-3:]),
[height, width, self.model_tester.embed_dim],
)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-3:]),
[height, width, self.model_tester.embed_dim],
)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()