Add focalnet backbone (#23104)

Adds FocalNet backbone to return features from all stages
This commit is contained in:
Alara Dirik
2023-05-03 19:32:42 +03:00
committed by GitHub
parent ca7eb27ed5
commit 441658dd6c
10 changed files with 210 additions and 11 deletions

View File

@@ -22,6 +22,7 @@ from transformers import FocalNetConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
@@ -30,7 +31,12 @@ if is_torch_available():
import torch
from torch import nn
from transformers import FocalNetForImageClassification, FocalNetForMaskedImageModeling, FocalNetModel
from transformers import (
FocalNetBackbone,
FocalNetForImageClassification,
FocalNetForMaskedImageModeling,
FocalNetModel,
)
from transformers.models.focalnet.modeling_focalnet import FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -48,6 +54,7 @@ class FocalNetModelTester:
patch_size=2,
num_channels=3,
embed_dim=16,
hidden_sizes=[32, 64, 128],
depths=[1, 2, 1],
num_heads=[2, 2, 4],
window_size=2,
@@ -67,6 +74,7 @@ class FocalNetModelTester:
type_sequence_label_size=10,
encoder_stride=8,
out_features=["stage1", "stage2"],
out_indices=[1, 2],
):
self.parent = parent
self.batch_size = batch_size
@@ -74,6 +82,7 @@ class FocalNetModelTester:
self.patch_size = patch_size
self.num_channels = num_channels
self.embed_dim = embed_dim
self.hidden_sizes = hidden_sizes
self.depths = depths
self.num_heads = num_heads
self.window_size = window_size
@@ -93,6 +102,7 @@ class FocalNetModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
self.out_features = out_features
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -111,6 +121,7 @@ class FocalNetModelTester:
patch_size=self.patch_size,
num_channels=self.num_channels,
embed_dim=self.embed_dim,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
num_heads=self.num_heads,
window_size=self.window_size,
@@ -126,6 +137,7 @@ class FocalNetModelTester:
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
out_features=self.out_features,
out_indices=self.out_indices,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -139,6 +151,35 @@ class FocalNetModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_backbone(self, config, pixel_values, labels):
model = FocalNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size, 8, 8])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[:-1])
# verify backbone works with out_features=None
config.out_features = None
model = FocalNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size * 2, 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = FocalNetForMaskedImageModeling(config=config)
model.to(torch_device)
@@ -191,6 +232,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
FocalNetModel,
FocalNetForImageClassification,
FocalNetForMaskedImageModeling,
FocalNetBackbone,
)
if is_torch_available()
else ()
@@ -204,7 +246,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = FocalNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37)
self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37, has_text_modality=False)
def test_config(self):
self.create_and_test_config_common_properties()
@@ -222,6 +264,10 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
@@ -234,14 +280,14 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@unittest.skip(reason="FocalNet Transformer does not use feedforward chunking")
@unittest.skip(reason="FocalNet does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
@@ -250,7 +296,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
@@ -309,7 +355,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
else (self.model_tester.image_size, self.model_tester.image_size)
)
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
@@ -337,7 +383,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
for model_class in self.all_model_classes:
for model_class in self.all_model_classes[:-1]:
inputs_dict["output_hidden_states"] = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
@@ -393,3 +439,14 @@ class FocalNetModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([0.2166, -0.4368, 0.2191]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
self.assertTrue(outputs.logits.argmax(dim=-1).item(), 281)
@require_torch
class FocalNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
all_model_classes = (FocalNetBackbone,) if is_torch_available() else ()
config_class = FocalNetConfig
has_attentions = False
def setUp(self):
self.model_tester = FocalNetModelTester(self)

View File

@@ -135,6 +135,8 @@ class BackboneTesterMixin:
# Verify num_features has been initialized in the backbone init
self.assertIsNotNone(backbone.num_features)
self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
print(backbone.stage_names)
print(backbone.num_features)
self.assertTrue(len(backbone.stage_names) == len(backbone.num_features))
self.assertTrue(len(backbone.channels) <= len(backbone.num_features))
self.assertTrue(len(backbone.out_feature_channels) == len(backbone.stage_names))