Add BeitBackbone (#25952)
* First draft * Add backwards compatibility * More improvements * More improvements * Improve error message * Address comment * Add conversion script * Fix style * Update code snippet * Adddress comment * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_backbone_common import BackboneTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
@@ -35,7 +36,9 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_BACKBONE_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
BeitBackbone,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
@@ -63,7 +66,7 @@ class BeitModelTester:
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
@@ -73,10 +76,11 @@ class BeitModelTester:
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
out_indices=[0, 1, 2, 3],
|
||||
out_indices=[1, 2, 3, 4],
|
||||
out_features=["stage1", "stage2", "stage3", "stage4"],
|
||||
):
|
||||
self.parent = parent
|
||||
self.vocab_size = 100
|
||||
self.vocab_size = vocab_size
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
@@ -94,6 +98,7 @@ class BeitModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.out_indices = out_indices
|
||||
self.out_features = out_features
|
||||
self.num_labels = num_labels
|
||||
|
||||
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
@@ -129,6 +134,7 @@ class BeitModelTester:
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
out_indices=self.out_indices,
|
||||
out_features=self.out_features,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
|
||||
@@ -138,6 +144,38 @@ class BeitModelTester:
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_backbone(self, config, pixel_values, labels, pixel_labels):
|
||||
model = BeitBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify hidden states
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
expected_height = expected_width = self.image_size // config.patch_size
|
||||
self.parent.assertListEqual(
|
||||
list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
|
||||
)
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
|
||||
# verify backbone works with out_features=None
|
||||
config.out_features = None
|
||||
model = BeitBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||
self.parent.assertListEqual(
|
||||
list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
|
||||
)
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), 1)
|
||||
|
||||
def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
|
||||
model = BeitForMaskedImageModeling(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -192,7 +230,13 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
|
||||
all_model_classes = (
|
||||
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
|
||||
(
|
||||
BeitModel,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitBackbone,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@@ -226,6 +270,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BEiT does not support feedforward chunking yet")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -239,6 +287,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
@@ -260,7 +312,11 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# we don't test BeitForMaskedImageModeling
|
||||
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
|
||||
if model_class in [
|
||||
*get_values(MODEL_MAPPING),
|
||||
*get_values(MODEL_FOR_BACKBONE_MAPPING),
|
||||
BeitForMaskedImageModeling,
|
||||
]:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
@@ -281,7 +337,8 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
# we don't test BeitForMaskedImageModeling
|
||||
if (
|
||||
model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
|
||||
model_class
|
||||
in [*get_values(MODEL_MAPPING), *get_values(MODEL_FOR_BACKBONE_MAPPING), BeitForMaskedImageModeling]
|
||||
or not model_class.supports_gradient_checkpointing
|
||||
):
|
||||
continue
|
||||
@@ -487,3 +544,12 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs)
|
||||
expected_shape = torch.Size((160, 160))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
all_model_classes = (BeitBackbone,) if is_torch_available() else ()
|
||||
config_class = BeitConfig
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BeitModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user