From 9932ee4b4bca9045d941af6687ef69eedcf68483 Mon Sep 17 00:00:00 2001 From: Francesco Saverio Zuppichini Date: Fri, 4 Mar 2022 19:11:48 +0100 Subject: [PATCH] made MaskFormerModelTest faster (#15942) --- tests/maskformer/test_modeling_maskformer.py | 34 ++++++++++++-------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/tests/maskformer/test_modeling_maskformer.py b/tests/maskformer/test_modeling_maskformer.py index f2e1f56f0f..3f885b3874 100644 --- a/tests/maskformer/test_modeling_maskformer.py +++ b/tests/maskformer/test_modeling_maskformer.py @@ -20,7 +20,7 @@ import unittest import numpy as np from tests.test_modeling_common import floats_tensor -from transformers import MaskFormerConfig, is_torch_available, is_vision_available +from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, require_vision, slow, torch_device @@ -47,12 +47,12 @@ class MaskFormerModelTester: batch_size=2, is_training=True, use_auxiliary_loss=False, - num_queries=100, + num_queries=10, num_channels=3, - min_size=384, - max_size=640, - num_labels=150, - mask_feature_size=256, + min_size=32 * 4, + max_size=32 * 6, + num_labels=4, + mask_feature_size=32, ): self.parent = parent self.batch_size = batch_size @@ -79,11 +79,20 @@ class MaskFormerModelTester: return config, pixel_values, pixel_mask, mask_labels, class_labels def get_config(self): - return MaskFormerConfig( - num_queries=self.num_queries, + return MaskFormerConfig.from_backbone_and_decoder_configs( + backbone_config=SwinConfig( + depths=[1, 1, 1, 1], + ), + decoder_config=DetrConfig( + decoder_ffn_dim=128, + num_queries=self.num_queries, + decoder_attention_heads=2, + d_model=self.mask_feature_size, + ), + mask_feature_size=self.mask_feature_size, + fpn_feature_size=self.mask_feature_size, num_channels=self.num_channels, num_labels=self.num_labels, - mask_feature_size=self.mask_feature_size, ) def prepare_config_and_inputs_for_common(self): @@ -161,7 +170,6 @@ class MaskFormerModelTester: @require_torch -@slow class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else () @@ -221,11 +229,11 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): model = MaskFormerModel.from_pretrained(model_name) self.assertIsNotNone(model) - @slow def test_model_with_labels(self): + size = (self.model_tester.min_size,) * 2 inputs = { - "pixel_values": torch.randn((2, 3, 384, 384)), - "mask_labels": torch.randn((2, 10, 384, 384)), + "pixel_values": torch.randn((2, 3, *size)), + "mask_labels": torch.randn((2, 10, *size)), "class_labels": torch.zeros(2, 10).long(), }