From c04619ecf3ce1975d2716ff77e29e2875d9ddd60 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 11 Apr 2022 18:23:35 +0200 Subject: [PATCH] Enable more test_torchscript (#16679) * update _create_and_check_torchscript * Enable test_torchscript * clear_class_registry Co-authored-by: ydshieh --- tests/beit/test_modeling_beit.py | 1 - tests/canine/test_modeling_canine.py | 1 - tests/clip/test_modeling_clip.py | 1 - tests/convnext/test_modeling_convnext.py | 1 - tests/ctrl/test_modeling_ctrl.py | 1 - tests/data2vec/test_modeling_data2vec_audio.py | 1 - .../test_modeling_decision_transformer.py | 1 + tests/deit/test_modeling_deit.py | 1 - tests/distilbert/test_modeling_distilbert.py | 1 - tests/dpt/test_modeling_dpt.py | 1 - tests/fnet/test_modeling_fnet.py | 1 - tests/glpn/test_modeling_glpn.py | 1 - tests/hubert/test_modeling_hubert.py | 2 -- tests/maskformer/test_modeling_maskformer.py | 1 - tests/mpnet/test_modeling_mpnet.py | 1 - tests/prophetnet/test_modeling_prophetnet.py | 5 ++--- tests/regnet/test_modeling_regnet.py | 1 - tests/resnet/test_modeling_resnet.py | 1 - tests/segformer/test_modeling_segformer.py | 1 - tests/sew/test_modeling_sew.py | 1 - tests/speech_to_text/test_modeling_speech_to_text.py | 1 - tests/squeezebert/test_modeling_squeezebert.py | 1 - tests/swin/test_modeling_swin.py | 1 - tests/t5/test_modeling_t5.py | 2 -- tests/tapas/test_modeling_tapas.py | 1 - tests/test_modeling_common.py | 10 ++++++---- tests/transfo_xl/test_modeling_transfo_xl.py | 1 - tests/unispeech/test_modeling_unispeech.py | 1 - tests/van/test_modeling_van.py | 1 - tests/vit/test_modeling_vit.py | 1 - tests/wav2vec2/test_modeling_wav2vec2.py | 2 -- tests/wavlm/test_modeling_wavlm.py | 1 - 32 files changed, 9 insertions(+), 39 deletions(-) diff --git a/tests/beit/test_modeling_beit.py b/tests/beit/test_modeling_beit.py index fb6e7a40d4..5b4421b4d7 100644 --- a/tests/beit/test_modeling_beit.py +++ b/tests/beit/test_modeling_beit.py @@ -192,7 +192,6 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False diff --git a/tests/canine/test_modeling_canine.py b/tests/canine/test_modeling_canine.py index dc873c86ac..5e3c37b37e 100644 --- a/tests/canine/test_modeling_canine.py +++ b/tests/canine/test_modeling_canine.py @@ -219,7 +219,6 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase): else () ) - test_torchscript = False test_mismatched_shapes = False test_resize_embeddings = False test_pruning = False diff --git a/tests/clip/test_modeling_clip.py b/tests/clip/test_modeling_clip.py index 09cc6cd61a..b9fc50c4b6 100644 --- a/tests/clip/test_modeling_clip.py +++ b/tests/clip/test_modeling_clip.py @@ -151,7 +151,6 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (CLIPVisionModel,) if is_torch_available() else () test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False diff --git a/tests/convnext/test_modeling_convnext.py b/tests/convnext/test_modeling_convnext.py index ab01d83907..68a42f38af 100644 --- a/tests/convnext/test_modeling_convnext.py +++ b/tests/convnext/test_modeling_convnext.py @@ -138,7 +138,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False has_attentions = False diff --git a/tests/ctrl/test_modeling_ctrl.py b/tests/ctrl/test_modeling_ctrl.py index 3daf31fd98..af754399b8 100644 --- a/tests/ctrl/test_modeling_ctrl.py +++ b/tests/ctrl/test_modeling_ctrl.py @@ -174,7 +174,6 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else () all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else () test_pruning = True - test_torchscript = False test_resize_embeddings = False test_head_masking = False diff --git a/tests/data2vec/test_modeling_data2vec_audio.py b/tests/data2vec/test_modeling_data2vec_audio.py index e5b8fd0e3a..ecadcb5903 100644 --- a/tests/data2vec/test_modeling_data2vec_audio.py +++ b/tests/data2vec/test_modeling_data2vec_audio.py @@ -372,7 +372,6 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = Data2VecAudioModelTester(self) diff --git a/tests/decision_transformer/test_modeling_decision_transformer.py b/tests/decision_transformer/test_modeling_decision_transformer.py index 41d86b4314..0843ce630e 100644 --- a/tests/decision_transformer/test_modeling_decision_transformer.py +++ b/tests/decision_transformer/test_modeling_decision_transformer.py @@ -148,6 +148,7 @@ class DecisionTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unit test_inputs_embeds = False test_model_common_attributes = False test_gradient_checkpointing = False + test_torchscript = False def setUp(self): self.model_tester = DecisionTransformerModelTester(self) diff --git a/tests/deit/test_modeling_deit.py b/tests/deit/test_modeling_deit.py index 7558b71565..f0d97c1369 100644 --- a/tests/deit/test_modeling_deit.py +++ b/tests/deit/test_modeling_deit.py @@ -169,7 +169,6 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False diff --git a/tests/distilbert/test_modeling_distilbert.py b/tests/distilbert/test_modeling_distilbert.py index 535ce2604d..2dfd31ac06 100644 --- a/tests/distilbert/test_modeling_distilbert.py +++ b/tests/distilbert/test_modeling_distilbert.py @@ -211,7 +211,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): ) fx_compatible = True test_pruning = True - test_torchscript = True test_resize_embeddings = True test_resize_position_embeddings = True diff --git a/tests/dpt/test_modeling_dpt.py b/tests/dpt/test_modeling_dpt.py index 93601a5723..aaa0c66f2e 100644 --- a/tests/dpt/test_modeling_dpt.py +++ b/tests/dpt/test_modeling_dpt.py @@ -154,7 +154,6 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (DPTModel, DPTForDepthEstimation, DPTForSemanticSegmentation) if is_torch_available() else () test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False diff --git a/tests/fnet/test_modeling_fnet.py b/tests/fnet/test_modeling_fnet.py index 5ab5c4a57c..b175bf3a54 100644 --- a/tests/fnet/test_modeling_fnet.py +++ b/tests/fnet/test_modeling_fnet.py @@ -284,7 +284,6 @@ class FNetModelTest(ModelTesterMixin, unittest.TestCase): # Skip Tests test_pruning = False - test_torchscript = False test_head_masking = False test_pruning = False diff --git a/tests/glpn/test_modeling_glpn.py b/tests/glpn/test_modeling_glpn.py index 1be6b9bf5c..323215d78b 100644 --- a/tests/glpn/test_modeling_glpn.py +++ b/tests/glpn/test_modeling_glpn.py @@ -150,7 +150,6 @@ class GLPNModelTest(ModelTesterMixin, unittest.TestCase): test_head_masking = False test_pruning = False test_resize_embeddings = False - test_torchscript = False def setUp(self): self.model_tester = GLPNModelTester(self) diff --git a/tests/hubert/test_modeling_hubert.py b/tests/hubert/test_modeling_hubert.py index 3177aa3d02..3c3ba96c99 100644 --- a/tests/hubert/test_modeling_hubert.py +++ b/tests/hubert/test_modeling_hubert.py @@ -300,7 +300,6 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else () test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = HubertModelTester(self) @@ -445,7 +444,6 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else () test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = HubertModelTester( diff --git a/tests/maskformer/test_modeling_maskformer.py b/tests/maskformer/test_modeling_maskformer.py index 50dbecb8de..43daf85ab1 100644 --- a/tests/maskformer/test_modeling_maskformer.py +++ b/tests/maskformer/test_modeling_maskformer.py @@ -177,7 +177,6 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else () is_encoder_decoder = False - test_torchscript = False test_pruning = False test_head_masking = False test_missing_keys = False diff --git a/tests/mpnet/test_modeling_mpnet.py b/tests/mpnet/test_modeling_mpnet.py index 6869a91c7a..5417313998 100644 --- a/tests/mpnet/test_modeling_mpnet.py +++ b/tests/mpnet/test_modeling_mpnet.py @@ -205,7 +205,6 @@ class MPNetModelTest(ModelTesterMixin, unittest.TestCase): else () ) test_pruning = False - test_torchscript = True test_resize_embeddings = True def setUp(self): diff --git a/tests/prophetnet/test_modeling_prophetnet.py b/tests/prophetnet/test_modeling_prophetnet.py index b8c5a8d8ec..17bf4523e7 100644 --- a/tests/prophetnet/test_modeling_prophetnet.py +++ b/tests/prophetnet/test_modeling_prophetnet.py @@ -890,7 +890,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else () test_pruning = False - test_torchscript = False test_resize_embeddings = False is_encoder_decoder = True @@ -1100,7 +1099,7 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else () all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else () test_pruning = False - test_torchscript = False + test_resize_embeddings = False is_encoder_decoder = False @@ -1128,7 +1127,7 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (ProphetNetEncoder,) if is_torch_available() else () test_pruning = False - test_torchscript = False + test_resize_embeddings = False is_encoder_decoder = False diff --git a/tests/regnet/test_modeling_regnet.py b/tests/regnet/test_modeling_regnet.py index 331e45296b..2660108e96 100644 --- a/tests/regnet/test_modeling_regnet.py +++ b/tests/regnet/test_modeling_regnet.py @@ -127,7 +127,6 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else () test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False has_attentions = False diff --git a/tests/resnet/test_modeling_resnet.py b/tests/resnet/test_modeling_resnet.py index 5880caa5e4..7a0d1ee473 100644 --- a/tests/resnet/test_modeling_resnet.py +++ b/tests/resnet/test_modeling_resnet.py @@ -127,7 +127,6 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else () test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False has_attentions = False diff --git a/tests/segformer/test_modeling_segformer.py b/tests/segformer/test_modeling_segformer.py index 8798a823ea..3c3f6ee5b4 100644 --- a/tests/segformer/test_modeling_segformer.py +++ b/tests/segformer/test_modeling_segformer.py @@ -165,7 +165,6 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase): test_head_masking = False test_pruning = False test_resize_embeddings = False - test_torchscript = False def setUp(self): self.model_tester = SegformerModelTester(self) diff --git a/tests/sew/test_modeling_sew.py b/tests/sew/test_modeling_sew.py index d3d44b52e1..e8b06610df 100644 --- a/tests/sew/test_modeling_sew.py +++ b/tests/sew/test_modeling_sew.py @@ -303,7 +303,6 @@ class SEWModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else () test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = SEWModelTester(self) diff --git a/tests/speech_to_text/test_modeling_speech_to_text.py b/tests/speech_to_text/test_modeling_speech_to_text.py index 03f2d81166..82b1c74c59 100644 --- a/tests/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/speech_to_text/test_modeling_speech_to_text.py @@ -273,7 +273,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes is_encoder_decoder = True test_pruning = False test_missing_keys = False - test_torchscript = True input_name = "input_features" diff --git a/tests/squeezebert/test_modeling_squeezebert.py b/tests/squeezebert/test_modeling_squeezebert.py index d1d446499a..c728aa2b0c 100644 --- a/tests/squeezebert/test_modeling_squeezebert.py +++ b/tests/squeezebert/test_modeling_squeezebert.py @@ -229,7 +229,6 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase): else None ) test_pruning = False - test_torchscript = True test_resize_embeddings = True test_head_masking = False diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index 414335ea4e..2147f578e7 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -177,7 +177,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False diff --git a/tests/t5/test_modeling_t5.py b/tests/t5/test_modeling_t5.py index e4bdf8ad30..8380484b06 100644 --- a/tests/t5/test_modeling_t5.py +++ b/tests/t5/test_modeling_t5.py @@ -512,7 +512,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): fx_compatible = True all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False - test_torchscript = True test_resize_embeddings = True test_model_parallel = True is_encoder_decoder = True @@ -777,7 +776,6 @@ class T5EncoderOnlyModelTester: class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (T5EncoderModel,) if is_torch_available() else () test_pruning = False - test_torchscript = True test_resize_embeddings = False test_model_parallel = True all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else () diff --git a/tests/tapas/test_modeling_tapas.py b/tests/tapas/test_modeling_tapas.py index 84fb0b4b3e..385af04ded 100644 --- a/tests/tapas/test_modeling_tapas.py +++ b/tests/tapas/test_modeling_tapas.py @@ -422,7 +422,6 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase): else None ) test_pruning = False - test_torchscript = False test_resize_embeddings = True test_head_masking = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4b54ec45d5..db24ece11f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -617,19 +617,21 @@ class ModelTesterMixin: model.eval() inputs = self._prepare_for_class(inputs_dict, model_class) + main_input_name = model_class.main_input_name + try: if model.config.is_encoder_decoder: model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - input_ids = inputs["input_ids"] + main_input = inputs[main_input_name] attention_mask = inputs["attention_mask"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] traced_model = torch.jit.trace( - model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) + model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask) ) else: - input_ids = inputs["input_ids"] - traced_model = torch.jit.trace(model, input_ids) + main_input = inputs[main_input_name] + traced_model = torch.jit.trace(model, main_input) except RuntimeError: self.fail("Couldn't trace module.") diff --git a/tests/transfo_xl/test_modeling_transfo_xl.py b/tests/transfo_xl/test_modeling_transfo_xl.py index 12098c5185..d4dbba448a 100644 --- a/tests/transfo_xl/test_modeling_transfo_xl.py +++ b/tests/transfo_xl/test_modeling_transfo_xl.py @@ -238,7 +238,6 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC ) all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else () test_pruning = False - test_torchscript = False test_resize_embeddings = True test_mismatched_shapes = False diff --git a/tests/unispeech/test_modeling_unispeech.py b/tests/unispeech/test_modeling_unispeech.py index b118120e58..9a25237bf3 100644 --- a/tests/unispeech/test_modeling_unispeech.py +++ b/tests/unispeech/test_modeling_unispeech.py @@ -305,7 +305,6 @@ class UniSpeechRobustModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = UniSpeechModelTester( diff --git a/tests/van/test_modeling_van.py b/tests/van/test_modeling_van.py index fe21c83c20..dff60fea38 100644 --- a/tests/van/test_modeling_van.py +++ b/tests/van/test_modeling_van.py @@ -124,7 +124,6 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (VanModel, VanForImageClassification) if is_torch_available() else () test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False has_attentions = False diff --git a/tests/vit/test_modeling_vit.py b/tests/vit/test_modeling_vit.py index a6f167b48b..db304aa815 100644 --- a/tests/vit/test_modeling_vit.py +++ b/tests/vit/test_modeling_vit.py @@ -158,7 +158,6 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False - test_torchscript = False test_resize_embeddings = False test_head_masking = False diff --git a/tests/wav2vec2/test_modeling_wav2vec2.py b/tests/wav2vec2/test_modeling_wav2vec2.py index 126882bf57..d6a2e7e67a 100644 --- a/tests/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/wav2vec2/test_modeling_wav2vec2.py @@ -413,7 +413,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = Wav2Vec2ModelTester(self) @@ -652,7 +651,6 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = Wav2Vec2ModelTester( diff --git a/tests/wavlm/test_modeling_wavlm.py b/tests/wavlm/test_modeling_wavlm.py index 7687dc3936..937325e721 100644 --- a/tests/wavlm/test_modeling_wavlm.py +++ b/tests/wavlm/test_modeling_wavlm.py @@ -316,7 +316,6 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase): ) test_pruning = False test_headmasking = False - test_torchscript = False def setUp(self): self.model_tester = WavLMModelTester(self)