made MaskFormerModelTest faster (#15942)
This commit is contained in:
committed by
GitHub
parent
e8efaecb87
commit
9932ee4b4b
@@ -20,7 +20,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import MaskFormerConfig, is_torch_available, is_vision_available
|
||||
from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
@@ -47,12 +47,12 @@ class MaskFormerModelTester:
|
||||
batch_size=2,
|
||||
is_training=True,
|
||||
use_auxiliary_loss=False,
|
||||
num_queries=100,
|
||||
num_queries=10,
|
||||
num_channels=3,
|
||||
min_size=384,
|
||||
max_size=640,
|
||||
num_labels=150,
|
||||
mask_feature_size=256,
|
||||
min_size=32 * 4,
|
||||
max_size=32 * 6,
|
||||
num_labels=4,
|
||||
mask_feature_size=32,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -79,11 +79,20 @@ class MaskFormerModelTester:
|
||||
return config, pixel_values, pixel_mask, mask_labels, class_labels
|
||||
|
||||
def get_config(self):
|
||||
return MaskFormerConfig(
|
||||
return MaskFormerConfig.from_backbone_and_decoder_configs(
|
||||
backbone_config=SwinConfig(
|
||||
depths=[1, 1, 1, 1],
|
||||
),
|
||||
decoder_config=DetrConfig(
|
||||
decoder_ffn_dim=128,
|
||||
num_queries=self.num_queries,
|
||||
decoder_attention_heads=2,
|
||||
d_model=self.mask_feature_size,
|
||||
),
|
||||
mask_feature_size=self.mask_feature_size,
|
||||
fpn_feature_size=self.mask_feature_size,
|
||||
num_channels=self.num_channels,
|
||||
num_labels=self.num_labels,
|
||||
mask_feature_size=self.mask_feature_size,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@@ -161,7 +170,6 @@ class MaskFormerModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
||||
@@ -221,11 +229,11 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = MaskFormerModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_model_with_labels(self):
|
||||
size = (self.model_tester.min_size,) * 2
|
||||
inputs = {
|
||||
"pixel_values": torch.randn((2, 3, 384, 384)),
|
||||
"mask_labels": torch.randn((2, 10, 384, 384)),
|
||||
"pixel_values": torch.randn((2, 3, *size)),
|
||||
"mask_labels": torch.randn((2, 10, *size)),
|
||||
"class_labels": torch.zeros(2, 10).long(),
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user