Add focalnet backbone (#23104)
Adds FocalNet backbone to return features from all stages
This commit is contained in:
@@ -22,6 +22,7 @@ from transformers import FocalNetConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_backbone_common import BackboneTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
|
||||
@@ -30,7 +31,12 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import FocalNetForImageClassification, FocalNetForMaskedImageModeling, FocalNetModel
|
||||
from transformers import (
|
||||
FocalNetBackbone,
|
||||
FocalNetForImageClassification,
|
||||
FocalNetForMaskedImageModeling,
|
||||
FocalNetModel,
|
||||
)
|
||||
from transformers.models.focalnet.modeling_focalnet import FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
if is_vision_available():
|
||||
@@ -48,6 +54,7 @@ class FocalNetModelTester:
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
embed_dim=16,
|
||||
hidden_sizes=[32, 64, 128],
|
||||
depths=[1, 2, 1],
|
||||
num_heads=[2, 2, 4],
|
||||
window_size=2,
|
||||
@@ -67,6 +74,7 @@ class FocalNetModelTester:
|
||||
type_sequence_label_size=10,
|
||||
encoder_stride=8,
|
||||
out_features=["stage1", "stage2"],
|
||||
out_indices=[1, 2],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -74,6 +82,7 @@ class FocalNetModelTester:
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.embed_dim = embed_dim
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.depths = depths
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
@@ -93,6 +102,7 @@ class FocalNetModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.encoder_stride = encoder_stride
|
||||
self.out_features = out_features
|
||||
self.out_indices = out_indices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -111,6 +121,7 @@ class FocalNetModelTester:
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
embed_dim=self.embed_dim,
|
||||
hidden_sizes=self.hidden_sizes,
|
||||
depths=self.depths,
|
||||
num_heads=self.num_heads,
|
||||
window_size=self.window_size,
|
||||
@@ -126,6 +137,7 @@ class FocalNetModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
out_features=self.out_features,
|
||||
out_indices=self.out_indices,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -139,6 +151,35 @@ class FocalNetModelTester:
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||
|
||||
def create_and_check_backbone(self, config, pixel_values, labels):
|
||||
model = FocalNetBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size, 8, 8])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
self.parent.assertListEqual(model.channels, config.hidden_sizes[:-1])
|
||||
|
||||
# verify backbone works with out_features=None
|
||||
config.out_features = None
|
||||
model = FocalNetBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size * 2, 4, 4])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), 1)
|
||||
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
|
||||
|
||||
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
|
||||
model = FocalNetForMaskedImageModeling(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -191,6 +232,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
FocalNetModel,
|
||||
FocalNetForImageClassification,
|
||||
FocalNetForMaskedImageModeling,
|
||||
FocalNetBackbone,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@@ -204,7 +246,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FocalNetModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37)
|
||||
self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37, has_text_modality=False)
|
||||
|
||||
def test_config(self):
|
||||
self.create_and_test_config_common_properties()
|
||||
@@ -222,6 +264,10 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||
|
||||
def test_for_masked_image_modeling(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
|
||||
@@ -234,14 +280,14 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="FocalNet Transformer does not use feedforward chunking")
|
||||
@unittest.skip(reason="FocalNet does not use feedforward chunking")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for model_class in self.all_model_classes[:-1]:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
||||
x = model.get_output_embeddings()
|
||||
@@ -250,7 +296,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for model_class in self.all_model_classes[:-1]:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
@@ -309,7 +355,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else (self.model_tester.image_size, self.model_tester.image_size)
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for model_class in self.all_model_classes[:-1]:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
|
||||
|
||||
@@ -337,7 +383,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
|
||||
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for model_class in self.all_model_classes[:-1]:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
|
||||
|
||||
@@ -393,3 +439,14 @@ class FocalNetModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice = torch.tensor([0.2166, -0.4368, 0.2191]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
self.assertTrue(outputs.logits.argmax(dim=-1).item(), 281)
|
||||
|
||||
|
||||
@require_torch
|
||||
class FocalNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FocalNetBackbone,) if is_torch_available() else ()
|
||||
config_class = FocalNetConfig
|
||||
|
||||
has_attentions = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FocalNetModelTester(self)
|
||||
|
||||
@@ -135,6 +135,8 @@ class BackboneTesterMixin:
|
||||
# Verify num_features has been initialized in the backbone init
|
||||
self.assertIsNotNone(backbone.num_features)
|
||||
self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
|
||||
print(backbone.stage_names)
|
||||
print(backbone.num_features)
|
||||
self.assertTrue(len(backbone.stage_names) == len(backbone.num_features))
|
||||
self.assertTrue(len(backbone.channels) <= len(backbone.num_features))
|
||||
self.assertTrue(len(backbone.out_feature_channels) == len(backbone.stage_names))
|
||||
|
||||
Reference in New Issue
Block a user