Add BeitBackbone (#25952)

* First draft

* Add backwards compatibility

* More improvements

* More improvements

* Improve error message

* Address comment

* Add conversion script

* Fix style

* Update code snippet

* Adddress comment

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
NielsRogge
2023-11-28 09:38:32 +01:00
committed by GitHub
parent 7a757bb694
commit 1fb3c23b41
11 changed files with 619 additions and 39 deletions

View File

@@ -25,6 +25,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -35,7 +36,9 @@ if is_torch_available():
from torch import nn
from transformers import (
MODEL_FOR_BACKBONE_MAPPING,
MODEL_MAPPING,
BeitBackbone,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
@@ -63,7 +66,7 @@ class BeitModelTester:
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=2,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
@@ -73,10 +76,11 @@ class BeitModelTester:
initializer_range=0.02,
num_labels=3,
scope=None,
out_indices=[0, 1, 2, 3],
out_indices=[1, 2, 3, 4],
out_features=["stage1", "stage2", "stage3", "stage4"],
):
self.parent = parent
self.vocab_size = 100
self.vocab_size = vocab_size
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
@@ -94,6 +98,7 @@ class BeitModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.out_indices = out_indices
self.out_features = out_features
self.num_labels = num_labels
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
@@ -129,6 +134,7 @@ class BeitModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
out_indices=self.out_indices,
out_features=self.out_features,
)
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
@@ -138,6 +144,38 @@ class BeitModelTester:
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_backbone(self, config, pixel_values, labels, pixel_labels):
model = BeitBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
expected_height = expected_width = self.image_size // config.patch_size
self.parent.assertListEqual(
list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
)
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = BeitBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(
list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
)
# verify channels
self.parent.assertEqual(len(model.channels), 1)
def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
model = BeitForMaskedImageModeling(config=config)
model.to(torch_device)
@@ -192,7 +230,13 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
all_model_classes = (
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
(
BeitModel,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitBackbone,
)
if is_torch_available()
else ()
)
@@ -226,6 +270,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="BEiT does not support feedforward chunking yet")
def test_feed_forward_chunking(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -239,6 +287,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
@@ -260,7 +312,11 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
# we don't test BeitForMaskedImageModeling
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
if model_class in [
*get_values(MODEL_MAPPING),
*get_values(MODEL_FOR_BACKBONE_MAPPING),
BeitForMaskedImageModeling,
]:
continue
model = model_class(config)
@@ -281,7 +337,8 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
# we don't test BeitForMaskedImageModeling
if (
model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
model_class
in [*get_values(MODEL_MAPPING), *get_values(MODEL_FOR_BACKBONE_MAPPING), BeitForMaskedImageModeling]
or not model_class.supports_gradient_checkpointing
):
continue
@@ -487,3 +544,12 @@ class BeitModelIntegrationTest(unittest.TestCase):
segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((160, 160))
self.assertEqual(segmentation[0].shape, expected_shape)
@require_torch
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
all_model_classes = (BeitBackbone,) if is_torch_available() else ()
config_class = BeitConfig
def setUp(self):
self.model_tester = BeitModelTester(self)