Update forward signature test for vision models (#27681)
* Update forward signature * Empty-Commit
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
""" Testing suite for the PyTorch MaskFormer model. """
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -266,18 +265,6 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ["facebook/maskformer-swin-small-coco"]:
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
""" Testing suite for the PyTorch MaskFormer Swin model. """
|
||||
|
||||
import collections
|
||||
import inspect
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
@@ -234,18 +233,6 @@ class MaskFormerSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
@unittest.skip(reason="MaskFormerSwin is only used as backbone and doesn't support output_attentions")
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user