made MaskFormerModelTest faster (#15942)
This commit is contained in:
committed by
GitHub
parent
e8efaecb87
commit
9932ee4b4b
@@ -20,7 +20,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor
|
from tests.test_modeling_common import floats_tensor
|
||||||
from transformers import MaskFormerConfig, is_torch_available, is_vision_available
|
from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||||
|
|
||||||
@@ -47,12 +47,12 @@ class MaskFormerModelTester:
|
|||||||
batch_size=2,
|
batch_size=2,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_auxiliary_loss=False,
|
use_auxiliary_loss=False,
|
||||||
num_queries=100,
|
num_queries=10,
|
||||||
num_channels=3,
|
num_channels=3,
|
||||||
min_size=384,
|
min_size=32 * 4,
|
||||||
max_size=640,
|
max_size=32 * 6,
|
||||||
num_labels=150,
|
num_labels=4,
|
||||||
mask_feature_size=256,
|
mask_feature_size=32,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -79,11 +79,20 @@ class MaskFormerModelTester:
|
|||||||
return config, pixel_values, pixel_mask, mask_labels, class_labels
|
return config, pixel_values, pixel_mask, mask_labels, class_labels
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return MaskFormerConfig(
|
return MaskFormerConfig.from_backbone_and_decoder_configs(
|
||||||
|
backbone_config=SwinConfig(
|
||||||
|
depths=[1, 1, 1, 1],
|
||||||
|
),
|
||||||
|
decoder_config=DetrConfig(
|
||||||
|
decoder_ffn_dim=128,
|
||||||
num_queries=self.num_queries,
|
num_queries=self.num_queries,
|
||||||
|
decoder_attention_heads=2,
|
||||||
|
d_model=self.mask_feature_size,
|
||||||
|
),
|
||||||
|
mask_feature_size=self.mask_feature_size,
|
||||||
|
fpn_feature_size=self.mask_feature_size,
|
||||||
num_channels=self.num_channels,
|
num_channels=self.num_channels,
|
||||||
num_labels=self.num_labels,
|
num_labels=self.num_labels,
|
||||||
mask_feature_size=self.mask_feature_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
@@ -161,7 +170,6 @@ class MaskFormerModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
|
||||||
class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
||||||
@@ -221,11 +229,11 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = MaskFormerModel.from_pretrained(model_name)
|
model = MaskFormerModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_model_with_labels(self):
|
def test_model_with_labels(self):
|
||||||
|
size = (self.model_tester.min_size,) * 2
|
||||||
inputs = {
|
inputs = {
|
||||||
"pixel_values": torch.randn((2, 3, 384, 384)),
|
"pixel_values": torch.randn((2, 3, *size)),
|
||||||
"mask_labels": torch.randn((2, 10, 384, 384)),
|
"mask_labels": torch.randn((2, 10, *size)),
|
||||||
"class_labels": torch.zeros(2, 10).long(),
|
"class_labels": torch.zeros(2, 10).long(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user