[DINOv2] Add backbone class (#25520)

* First draft

* More improvements

* Fix all tests

* More improvements

* Add backbone test

* Improve docstring

* Address comments

* Rename attribute

* Remove expected output

* Update src/transformers/models/dinov2/modeling_dinov2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix style

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
NielsRogge
2023-08-29 12:05:27 +02:00
committed by GitHub
parent 4c21da5e34
commit 77713d11f6
8 changed files with 222 additions and 9 deletions

View File

@@ -27,6 +27,7 @@ from transformers.testing_utils import (
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -36,7 +37,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import Dinov2ForImageClassification, Dinov2Model
from transformers import Dinov2Backbone, Dinov2ForImageClassification, Dinov2Model
from transformers.models.dinov2.modeling_dinov2 import DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -123,6 +124,53 @@ class Dinov2ModelTester:
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_backbone(self, config, pixel_values, labels):
    """Exercise `Dinov2Backbone` under three successive configurations.

    The same `config` object is mutated between stages, so every stage builds
    on the settings left behind by the previous one:

    1. `config` exactly as passed in: one feature map and one `channels`
       entry per item of `config.out_features`.
    2. `out_features = None`: exactly one feature map and one `channels` entry.
    3. additionally `apply_layernorm = False` and
       `reshape_hidden_states = False`: one feature map whose shape is
       `[batch_size, seq_length, hidden_size]`.

    In stages 1 and 2 the first feature map must have shape
    `[batch_size, channels[0], s, s]` with `s = image_size // patch_size`.

    `labels` is unused; it is accepted so the method can be called with the
    unpacked tuple returned by `prepare_config_and_inputs`.
    """

    def forward_fresh_backbone():
        # Instantiate from the *current* state of `config`, which changes between stages.
        backbone = Dinov2Backbone(config=config)
        backbone.to(torch_device)
        backbone.eval()
        return backbone, backbone(pixel_values)

    check = self.parent

    # Stage 1: configuration as supplied by the caller.
    backbone, outputs = forward_fresh_backbone()
    check.assertEqual(len(outputs.feature_maps), len(config.out_features))
    grid_side = self.image_size // config.patch_size
    first_map_shape = list(outputs.feature_maps[0].shape)
    check.assertListEqual(first_map_shape, [self.batch_size, backbone.channels[0], grid_side, grid_side])
    check.assertEqual(len(backbone.channels), len(config.out_features))

    # Stage 2: no explicit out_features.
    config.out_features = None
    backbone, outputs = forward_fresh_backbone()
    check.assertEqual(len(outputs.feature_maps), 1)
    first_map_shape = list(outputs.feature_maps[0].shape)
    check.assertListEqual(first_map_shape, [self.batch_size, backbone.channels[0], grid_side, grid_side])
    check.assertEqual(len(backbone.channels), 1)

    # Stage 3: layernorm and reshaping switched off (out_features is still None here).
    config.apply_layernorm = False
    config.reshape_hidden_states = False
    _, outputs = forward_fresh_backbone()
    check.assertEqual(len(outputs.feature_maps), 1)
    first_map_shape = list(outputs.feature_maps[0].shape)
    check.assertListEqual(first_map_shape, [self.batch_size, self.seq_length, self.hidden_size])
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = Dinov2ForImageClassification(config)
@@ -159,7 +207,15 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
all_model_classes = (Dinov2Model, Dinov2ForImageClassification) if is_torch_available() else ()
all_model_classes = (
(
Dinov2Model,
Dinov2ForImageClassification,
Dinov2Backbone,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
if is_torch_available()
@@ -207,10 +263,18 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
    # Build a fresh (config, pixel_values, labels) tuple and hand it to the backbone checks.
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_backbone(*prepared)
def test_for_image_classification(self):
    # Build a fresh config/inputs tuple and hand it to the image-classification checks.
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_image_classification(*prepared)
@unittest.skip(reason="Dinov2 does not support feedforward chunking yet")
def test_feed_forward_chunking(self):
    # Intentionally empty: the method exists only to carry the skip marker above.
    # NOTE(review): presumably this shadows a same-named test inherited from a mixin; confirm.
    return None
@slow
def test_model_from_pretrained(self):
for model_name in DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -252,3 +316,14 @@ class Dinov2ModelIntegrationTest(unittest.TestCase):
device=torch_device,
)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
@require_torch
class Dinov2BackboneTest(unittest.TestCase, BackboneTesterMixin):
    """Test case combining `unittest.TestCase` with `BackboneTesterMixin` for `Dinov2Backbone`.

    NOTE(review): the class attributes below are presumably read by the mixin to
    decide what to instantiate and which checks to run; confirm against
    `test_backbone_common`.
    """

    config_class = Dinov2Config
    has_attentions = False
    # Empty when torch is unavailable, since `Dinov2Backbone` is only imported under that guard.
    all_model_classes = (Dinov2Backbone,) if is_torch_available() else ()

    def setUp(self):
        # Reuse the regular Dinov2 tester, handing it this test case; the tester's
        # checks call assertions on its `parent`.
        self.model_tester = Dinov2ModelTester(self)