Check models used for common tests are small (#24824)
* First models * Conditional DETR * Treat DETR models, skip others * Skip LayoutLMv2 as well * Fix last tests
This commit is contained in:
@@ -20,9 +20,16 @@ import math
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from transformers import DeformableDetrConfig, is_timm_available, is_vision_available
|
||||
from transformers import DeformableDetrConfig, ResNetConfig, is_torch_available, is_vision_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_timm, require_torch_gpu, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_timm,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -30,10 +37,10 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_timm_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import DeformableDetrForObjectDetection, DeformableDetrModel, ResNetConfig
|
||||
from transformers import DeformableDetrForObjectDetection, DeformableDetrModel
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@@ -49,7 +56,7 @@ class DeformableDetrModelTester:
|
||||
batch_size=8,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=256,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=8,
|
||||
intermediate_size=4,
|
||||
@@ -116,6 +123,16 @@ class DeformableDetrModelTester:
|
||||
return config, pixel_values, pixel_mask, labels
|
||||
|
||||
def get_config(self):
|
||||
resnet_config = ResNetConfig(
|
||||
num_channels=3,
|
||||
embeddings_size=10,
|
||||
hidden_sizes=[10, 20, 30, 40],
|
||||
depths=[1, 1, 2, 1],
|
||||
hidden_act="relu",
|
||||
num_labels=3,
|
||||
out_features=["stage2", "stage3", "stage4"],
|
||||
out_indices=[2, 3, 4],
|
||||
)
|
||||
return DeformableDetrConfig(
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
@@ -131,6 +148,8 @@ class DeformableDetrModelTester:
|
||||
num_feature_levels=self.num_feature_levels,
|
||||
encoder_n_points=self.encoder_n_points,
|
||||
decoder_n_points=self.decoder_n_points,
|
||||
use_timm_backbone=False,
|
||||
backbone_config=resnet_config,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@@ -165,32 +184,13 @@ class DeformableDetrModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||
|
||||
def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
|
||||
config.use_timm_backbone = False
|
||||
config.backbone_config = ResNetConfig()
|
||||
model = DeformableDetrForObjectDetection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||
|
||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
|
||||
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||
|
||||
|
||||
@require_timm
|
||||
@require_torch
|
||||
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_timm_available() else ()
|
||||
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
|
||||
if is_timm_available()
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
is_encoder_decoder = True
|
||||
@@ -246,10 +246,6 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs)
|
||||
|
||||
def test_deformable_detr_no_timm_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Deformable DETR does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user