[NAT, DiNAT] Add backbone class (#20654)
* Add first draft * Add out_features attribute to config * Add corresponding test * Add Dinat backbone * Add BackboneMixin * Add Backbone mixin, improve tests * Fix embeddings * Fix bug * Improve backbones * Fix Nat backbone tests * Fix Dinat backbone tests * Apply suggestions Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -30,7 +30,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import NatForImageClassification, NatModel
|
||||
from transformers import NatBackbone, NatForImageClassification, NatModel
|
||||
from transformers.models.nat.modeling_nat import NAT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
if is_vision_available():
|
||||
@@ -63,8 +63,8 @@ class NatModelTester:
|
||||
is_training=True,
|
||||
scope=None,
|
||||
use_labels=True,
|
||||
type_sequence_label_size=10,
|
||||
encoder_stride=8,
|
||||
num_labels=10,
|
||||
out_features=["stage1", "stage2"],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -87,15 +87,15 @@ class NatModelTester:
|
||||
self.is_training = is_training
|
||||
self.scope = scope
|
||||
self.use_labels = use_labels
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.encoder_stride = encoder_stride
|
||||
self.num_labels = num_labels
|
||||
self.out_features = out_features
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
labels = ids_tensor([self.batch_size], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
@@ -103,6 +103,7 @@ class NatModelTester:
|
||||
|
||||
def get_config(self):
|
||||
return NatConfig(
|
||||
num_labels=self.num_labels,
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
@@ -119,7 +120,7 @@ class NatModelTester:
|
||||
patch_norm=self.patch_norm,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
out_features=self.out_features,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -136,12 +137,11 @@ class NatModelTester:
|
||||
)
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = NatForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
# test greyscale images
|
||||
config.num_channels = 1
|
||||
@@ -151,7 +151,34 @@ class NatModelTester:
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_backbone(self, config, pixel_values, labels):
|
||||
model = NatBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify hidden states
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
|
||||
# verify backbone works with out_features=None
|
||||
config.out_features = None
|
||||
model = NatBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), 1)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
@@ -164,7 +191,15 @@ class NatModelTester:
|
||||
@require_torch
|
||||
class NatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (NatModel, NatForImageClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
NatModel,
|
||||
NatForImageClassification,
|
||||
NatBackbone,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_compatible = False
|
||||
|
||||
test_torchscript = False
|
||||
@@ -196,8 +231,16 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
def test_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Nat does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
# Nat does not use inputs_embeds
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Nat does not use feedforward chunking")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
@@ -254,17 +297,18 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
[height, width, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||
if model_class.__name__ != "NatBackbone":
|
||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||
|
||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||
reshaped_hidden_states = (
|
||||
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(reshaped_hidden_states.shape[-3:]),
|
||||
[height, width, self.model_tester.embed_dim],
|
||||
)
|
||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||
reshaped_hidden_states = (
|
||||
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(reshaped_hidden_states.shape[-3:]),
|
||||
[height, width, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user