Fix tests for vision models (#35654)

* Trigger tests

* [run-slow] beit, detr, dinov2, vit, textnet

* Fix BEiT interpolate_pos_encoding

* Fix DETR test

* Update DINOv2 test

* Fix TextNet

* Fix ViT

* Fix DPT

* Fix data2vec test

* Fix TextNet test

* Update interpolation check

* Fix ZoeDepth tests

* Update interpolate embeddings for BEiT

* Apply suggestions from code review
This commit is contained in:
Pavel Iakubovskii
2025-02-13 10:28:37 +00:00
committed by GitHub
parent e60ae0d078
commit d419862889
9 changed files with 55 additions and 79 deletions

View File

@@ -16,6 +16,7 @@
import collections.abc import collections.abc
import math import math
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -196,12 +197,16 @@ class BeitEmbeddings(nn.Module):
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None, bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False, interpolate_pos_encoding: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.position_embeddings is not None and interpolate_pos_encoding is not None:
warnings.warn(
"`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
"interpolated to the input image size. The argument will be removed in transformers v4.51.0."
)
_, _, height, width = pixel_values.shape _, _, height, width = pixel_values.shape
embeddings, (patch_height, patch_width) = self.patch_embeddings( embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
)
batch_size, seq_len, _ = embeddings.size() batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None: if bool_masked_pos is not None:
@@ -211,14 +216,11 @@ class BeitEmbeddings(nn.Module):
embeddings = embeddings * (1 - w) + mask_tokens * w embeddings = embeddings * (1 - w) + mask_tokens * w
cls_tokens = self.cls_token.expand(batch_size, -1, -1) cls_tokens = self.cls_token.expand(batch_size, -1, -1)
if self.position_embeddings is not None:
if interpolate_pos_encoding:
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
else:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
embeddings = torch.cat((cls_tokens, embeddings), dim=1) embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if self.position_embeddings is not None:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
return embeddings, (patch_height, patch_width) return embeddings, (patch_height, patch_width)
@@ -248,11 +250,7 @@ class BeitPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward( def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
self,
pixel_values: torch.Tensor,
position_embedding: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels: if num_channels != self.num_channels:
raise ValueError( raise ValueError(
@@ -261,17 +259,6 @@ class BeitPatchEmbeddings(nn.Module):
embeddings = self.projection(pixel_values) embeddings = self.projection(pixel_values)
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
if position_embedding is not None:
# interpolate the position embedding to the corresponding size
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(
0, 3, 1, 2
)
position_embedding = nn.functional.interpolate(
position_embedding, size=(patch_height, patch_width), mode="bicubic"
)
embeddings = embeddings + position_embedding
embeddings = embeddings.flatten(2).transpose(1, 2) embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, (patch_height, patch_width) return embeddings, (patch_height, patch_width)
@@ -887,9 +874,7 @@ class BeitModel(BeitPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, _ = self.embeddings( embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
resolution = pixel_values.shape[2:] resolution = pixel_values.shape[2:]
encoder_outputs = self.encoder( encoder_outputs = self.encoder(

View File

@@ -16,6 +16,7 @@
import collections.abc import collections.abc
import math import math
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -195,12 +196,16 @@ class Data2VecVisionEmbeddings(nn.Module):
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None, bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False, interpolate_pos_encoding: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.position_embeddings is not None and interpolate_pos_encoding is not None:
warnings.warn(
"`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
"interpolated to the input image size. The argument will be removed in transformers v4.51.0."
)
_, _, height, width = pixel_values.shape _, _, height, width = pixel_values.shape
embeddings, (patch_height, patch_width) = self.patch_embeddings( embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
)
batch_size, seq_len, _ = embeddings.size() batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None: if bool_masked_pos is not None:
@@ -210,14 +215,11 @@ class Data2VecVisionEmbeddings(nn.Module):
embeddings = embeddings * (1 - w) + mask_tokens * w embeddings = embeddings * (1 - w) + mask_tokens * w
cls_tokens = self.cls_token.expand(batch_size, -1, -1) cls_tokens = self.cls_token.expand(batch_size, -1, -1)
if self.position_embeddings is not None:
if interpolate_pos_encoding:
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
else:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
embeddings = torch.cat((cls_tokens, embeddings), dim=1) embeddings = torch.cat((cls_tokens, embeddings), dim=1)
if self.position_embeddings is not None:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
embeddings = self.dropout(embeddings) embeddings = self.dropout(embeddings)
return embeddings, (patch_height, patch_width) return embeddings, (patch_height, patch_width)
@@ -248,11 +250,7 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward( def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
self,
pixel_values: torch.Tensor,
position_embedding: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels: if num_channels != self.num_channels:
raise ValueError( raise ValueError(
@@ -261,17 +259,6 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
embeddings = self.projection(pixel_values) embeddings = self.projection(pixel_values)
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
if position_embedding is not None:
# interpolate the position embedding to the corresponding size
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(
0, 3, 1, 2
)
position_embedding = nn.functional.interpolate(
position_embedding, size=(patch_height, patch_width), mode="bicubic"
)
embeddings = embeddings + position_embedding
embeddings = embeddings.flatten(2).transpose(1, 2) embeddings = embeddings.flatten(2).transpose(1, 2)
return embeddings, (patch_height, patch_width) return embeddings, (patch_height, patch_width)
@@ -902,9 +889,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, _ = self.embeddings( embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
resolution = pixel_values.shape[2:] resolution = pixel_values.shape[2:]
encoder_outputs = self.encoder( encoder_outputs = self.encoder(

View File

@@ -774,7 +774,9 @@ class BeitModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True) outputs = model(pixel_values, interpolate_pos_encoding=True)
expected_shape = torch.Size((1, 1801, 768)) # num_cls_tokens + (height / patch_size) * (width / patch_size)
# 1 + (480 / 16) * (480 / 16) = 1 + 30 * 30 = 901
expected_shape = torch.Size((1, 901, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape) self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

View File

@@ -565,17 +565,12 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase):
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480}) inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
pixel_values = inputs.pixel_values.to(torch_device) pixel_values = inputs.pixel_values.to(torch_device)
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
# images than what the model supports.
self.assertFalse(processor.do_center_crop)
with torch.no_grad():
with self.assertRaises(ValueError, msg="doesn't match model"):
model(pixel_values, interpolate_pos_encoding=False)
# with interpolate_pos_encoding being True the model should process the higher resolution image # with interpolate_pos_encoding being True the model should process the higher resolution image
# successfully and produce the expected output. # successfully and produce the expected output.
with torch.no_grad(): with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True) outputs = model(pixel_values, interpolate_pos_encoding=True)
expected_shape = torch.Size((1, 1801, 768)) # num_cls_tokens + (height / patch_size) * (width / patch_size)
# 1 + (480 / 16) * (480 / 16) = 901
expected_shape = torch.Size((1, 901, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape) self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

View File

@@ -684,7 +684,12 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
self.assertTrue(results["segmentation"].shape, expected_shape) self.assertTrue(results["segmentation"].shape, expected_shape)
torch.testing.assert_close(results["segmentation"][:3, :3], expected_slice_segmentation, rtol=1e-4, atol=1e-4) torch.testing.assert_close(results["segmentation"][:3, :3], expected_slice_segmentation, rtol=1e-4, atol=1e-4)
self.assertTrue(len(results["segments_info"]), expected_number_of_segments) self.assertTrue(len(results["segments_info"]), expected_number_of_segments)
self.assertDictEqual(results["segments_info"][0], expected_first_segment)
predicted_first_segment = results["segments_info"][0]
self.assertEqual(predicted_first_segment["id"], expected_first_segment["id"])
self.assertEqual(predicted_first_segment["label_id"], expected_first_segment["label_id"])
self.assertEqual(predicted_first_segment["was_fused"], expected_first_segment["was_fused"])
self.assertAlmostEqual(predicted_first_segment["score"], expected_first_segment["score"], places=3)
@require_vision @require_vision

View File

@@ -329,10 +329,10 @@ class Dinov2ModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.last_hidden_state.shape, expected_shape) self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor( expected_slice = torch.tensor(
[[-2.1747, -0.4729, 1.0936], [-3.2780, -0.8269, -0.9210], [-2.9129, 1.1284, -0.7306]], [[-2.2005, -0.4495, 1.0964], [-3.3959, -0.8942, -1.0315], [-2.9355, 1.1564, -0.7656]],
device=torch_device, device=torch_device,
) )
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
@require_torch @require_torch

View File

@@ -328,14 +328,18 @@ class TextNetModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
output = model(**inputs) output = model(**inputs)
# verify logits # verify output
self.assertEqual(output.logits.shape, torch.Size([1, 2])) self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 512, 20, 27]))
expected_slice_backbone = torch.tensor( expected_slice_backbone = torch.tensor(
[0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000], [
[0.0000, 1.7415, 1.2660],
[0.0000, 1.0084, 1.9692],
[0.0000, 1.7464, 1.7892],
],
device=torch_device, device=torch_device,
) )
torch.testing.assert_close( torch.testing.assert_close(
output.feature_maps[-1][0][10][12][:10], expected_slice_backbone, rtol=1e-3, atol=1e-3 output.last_hidden_state[0, 12, :3, :3], expected_slice_backbone, rtol=1e-2, atol=1e-2
) )

View File

@@ -310,10 +310,10 @@ class ViTModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.last_hidden_state.shape, expected_shape) self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor( expected_slice = torch.tensor(
[[4.2340, 4.3906, -6.6692], [4.5463, 1.8928, -6.7257], [4.4429, 0.8496, -5.8585]] [[4.2325, 4.3882, -6.6678], [4.5372, 1.8933, -6.7355], [4.4454, 0.8514, -5.8747]]
).to(torch_device) ).to(torch_device)
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
@slow @slow
@require_accelerate @require_accelerate

View File

@@ -301,8 +301,8 @@ class ZoeDepthModelIntegrationTest(unittest.TestCase):
out_l_reduced = torch.nn.functional.interpolate( out_l_reduced = torch.nn.functional.interpolate(
out_l.unsqueeze(0).unsqueeze(1), size=img.size[::-1], mode="bicubic", align_corners=False out_l.unsqueeze(0).unsqueeze(1), size=img.size[::-1], mode="bicubic", align_corners=False
) )
self.assertTrue((np.array(out_l.shape)[::-1] == np.array(img.size) * 2).all()) out_l_reduced = out_l_reduced.squeeze(0).squeeze(0)
torch.testing.assert_close(out, out_l_reduced, rtol=2e-2) torch.testing.assert_close(out, out_l_reduced, rtol=2e-2, atol=2e-2)
def check_post_processing_test(self, image_processor, images, model, pad_input=True, flip_aug=True): def check_post_processing_test(self, image_processor, images, model, pad_input=True, flip_aug=True):
inputs = image_processor(images=images, return_tensors="pt", do_pad=pad_input).to(torch_device) inputs = image_processor(images=images, return_tensors="pt", do_pad=pad_input).to(torch_device)
@@ -324,7 +324,7 @@ class ZoeDepthModelIntegrationTest(unittest.TestCase):
for img, out, expected_slice in zip(images, outputs, expected_slices): for img, out, expected_slice in zip(images, outputs, expected_slices):
out = out["predicted_depth"] out = out["predicted_depth"]
self.assertTrue(img.size == out.shape[::-1]) self.assertTrue(img.size == out.shape[::-1])
torch.testing.assert_close(expected_slice, out[:3, :3], atol=1e-3, rtol=1e-3) torch.testing.assert_close(expected_slice, out[:3, :3], rtol=1e-3, atol=1e-3)
self.check_target_size(image_processor, pad_input, images, outputs, raw_outputs, raw_outputs_flipped) self.check_target_size(image_processor, pad_input, images, outputs, raw_outputs, raw_outputs_flipped)