Enable more test_torchscript (#16679)
* update _create_and_check_torchscript * Enable test_torchscript * clear_class_registry Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -192,7 +192,6 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -219,7 +219,6 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
|
||||||
test_torchscript = False
|
|
||||||
test_mismatched_shapes = False
|
test_mismatched_shapes = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -151,7 +151,6 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
|
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -138,7 +138,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
has_attentions = False
|
has_attentions = False
|
||||||
|
|||||||
@@ -174,7 +174,6 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else ()
|
all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
|
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
|
||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -372,7 +372,6 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = Data2VecAudioModelTester(self)
|
self.model_tester = Data2VecAudioModelTester(self)
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ class DecisionTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unit
|
|||||||
test_inputs_embeds = False
|
test_inputs_embeds = False
|
||||||
test_model_common_attributes = False
|
test_model_common_attributes = False
|
||||||
test_gradient_checkpointing = False
|
test_gradient_checkpointing = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = DecisionTransformerModelTester(self)
|
self.model_tester = DecisionTransformerModelTester(self)
|
||||||
|
|||||||
@@ -169,7 +169,6 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -211,7 +211,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
fx_compatible = True
|
fx_compatible = True
|
||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_torchscript = True
|
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_resize_position_embeddings = True
|
test_resize_position_embeddings = True
|
||||||
|
|
||||||
|
|||||||
@@ -154,7 +154,6 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (DPTModel, DPTForDepthEstimation, DPTForSemanticSegmentation) if is_torch_available() else ()
|
all_model_classes = (DPTModel, DPTForDepthEstimation, DPTForSemanticSegmentation) if is_torch_available() else ()
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -284,7 +284,6 @@ class FNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# Skip Tests
|
# Skip Tests
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,6 @@ class GLPNModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = GLPNModelTester(self)
|
self.model_tester = GLPNModelTester(self)
|
||||||
|
|||||||
@@ -300,7 +300,6 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = HubertModelTester(self)
|
self.model_tester = HubertModelTester(self)
|
||||||
@@ -445,7 +444,6 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = HubertModelTester(
|
self.model_tester = HubertModelTester(
|
||||||
|
|||||||
@@ -177,7 +177,6 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
||||||
|
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
test_torchscript = False
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|||||||
@@ -205,7 +205,6 @@ class MPNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -890,7 +890,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
|
||||||
@@ -1100,7 +1099,7 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
|
|||||||
all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else ()
|
all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
@@ -1128,7 +1127,7 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
|
|||||||
class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
|
class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (ProphetNetEncoder,) if is_torch_available() else ()
|
all_model_classes = (ProphetNetEncoder,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
|||||||
@@ -127,7 +127,6 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
|
all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
has_attentions = False
|
has_attentions = False
|
||||||
|
|||||||
@@ -127,7 +127,6 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()
|
all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
has_attentions = False
|
has_attentions = False
|
||||||
|
|||||||
@@ -165,7 +165,6 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = SegformerModelTester(self)
|
self.model_tester = SegformerModelTester(self)
|
||||||
|
|||||||
@@ -303,7 +303,6 @@ class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else ()
|
all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = SEWModelTester(self)
|
self.model_tester = SEWModelTester(self)
|
||||||
|
|||||||
@@ -273,7 +273,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
|||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
test_torchscript = True
|
|
||||||
|
|
||||||
input_name = "input_features"
|
input_name = "input_features"
|
||||||
|
|
||||||
|
|||||||
@@ -229,7 +229,6 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -177,7 +177,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -512,7 +512,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
fx_compatible = True
|
fx_compatible = True
|
||||||
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_model_parallel = True
|
test_model_parallel = True
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
@@ -777,7 +776,6 @@ class T5EncoderOnlyModelTester:
|
|||||||
class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (T5EncoderModel,) if is_torch_available() else ()
|
all_model_classes = (T5EncoderModel,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_model_parallel = True
|
test_model_parallel = True
|
||||||
all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else ()
|
all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else ()
|
||||||
|
|||||||
@@ -422,7 +422,6 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -617,19 +617,21 @@ class ModelTesterMixin:
|
|||||||
model.eval()
|
model.eval()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
main_input_name = model_class.main_input_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
input_ids = inputs["input_ids"]
|
main_input = inputs[main_input_name]
|
||||||
attention_mask = inputs["attention_mask"]
|
attention_mask = inputs["attention_mask"]
|
||||||
decoder_input_ids = inputs["decoder_input_ids"]
|
decoder_input_ids = inputs["decoder_input_ids"]
|
||||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||||
traced_model = torch.jit.trace(
|
traced_model = torch.jit.trace(
|
||||||
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_ids = inputs["input_ids"]
|
main_input = inputs[main_input_name]
|
||||||
traced_model = torch.jit.trace(model, input_ids)
|
traced_model = torch.jit.trace(model, main_input)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.fail("Couldn't trace module.")
|
self.fail("Couldn't trace module.")
|
||||||
|
|
||||||
|
|||||||
@@ -238,7 +238,6 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
|
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_mismatched_shapes = False
|
test_mismatched_shapes = False
|
||||||
|
|
||||||
|
|||||||
@@ -305,7 +305,6 @@ class UniSpeechRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = UniSpeechModelTester(
|
self.model_tester = UniSpeechModelTester(
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (VanModel, VanForImageClassification) if is_torch_available() else ()
|
all_model_classes = (VanModel, VanForImageClassification) if is_torch_available() else ()
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
has_attentions = False
|
has_attentions = False
|
||||||
|
|||||||
@@ -158,7 +158,6 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
|
|||||||
@@ -413,7 +413,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = Wav2Vec2ModelTester(self)
|
self.model_tester = Wav2Vec2ModelTester(self)
|
||||||
@@ -652,7 +651,6 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = Wav2Vec2ModelTester(
|
self.model_tester = Wav2Vec2ModelTester(
|
||||||
|
|||||||
@@ -316,7 +316,6 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = WavLMModelTester(self)
|
self.model_tester = WavLMModelTester(self)
|
||||||
|
|||||||
Reference in New Issue
Block a user