From fa01127a677103f359597656c9d995d92b517f71 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 6 Apr 2023 17:56:06 +0200 Subject: [PATCH] update_pip_test_mapping (#22606) * Add TFBlipForConditionalGeneration * update pipeline_model_mapping * Add import * Revert changes in GPTSanJapaneseTest --------- Co-authored-by: ydshieh --- src/transformers/models/auto/modeling_tf_auto.py | 1 + tests/models/align/test_modeling_align.py | 4 +++- tests/models/bart/test_modeling_bart.py | 3 ++- tests/models/bart/test_modeling_tf_bart.py | 3 ++- .../bigbird_pegasus/test_modeling_bigbird_pegasus.py | 3 ++- tests/models/blenderbot/test_modeling_blenderbot.py | 3 ++- tests/models/blenderbot/test_modeling_tf_blenderbot.py | 1 + .../blenderbot_small/test_modeling_blenderbot_small.py | 3 ++- .../blenderbot_small/test_modeling_tf_blenderbot_small.py | 1 + tests/models/convnextv2/test_modeling_convnextv2.py | 8 +++++++- tests/models/efficientnet/test_modeling_efficientnet.py | 8 +++++++- tests/models/fsmt/test_modeling_fsmt.py | 1 + .../gptsan_japanese/test_modeling_gptsan_japanese.py | 2 +- tests/models/informer/test_modeling_informer.py | 4 +++- tests/models/led/test_modeling_led.py | 3 ++- tests/models/led/test_modeling_tf_led.py | 1 + tests/models/longt5/test_modeling_longt5.py | 1 + tests/models/m2m_100/test_modeling_m2m_100.py | 1 + tests/models/marian/test_modeling_marian.py | 3 ++- tests/models/marian/test_modeling_tf_marian.py | 1 + tests/models/mbart/test_modeling_mbart.py | 3 ++- tests/models/mbart/test_modeling_tf_mbart.py | 1 + tests/models/mega/test_modeling_mega.py | 2 ++ tests/models/mgp_str/test_modeling_mgp_str.py | 4 +++- tests/models/mvp/test_modeling_mvp.py | 3 ++- tests/models/nllb_moe/test_modeling_nllb_moe.py | 1 + tests/models/pegasus/test_modeling_pegasus.py | 3 ++- tests/models/pegasus/test_modeling_tf_pegasus.py | 1 + tests/models/pegasus_x/test_modeling_pegasus_x.py | 1 + tests/models/plbart/test_modeling_plbart.py | 3 ++- tests/models/prophetnet/test_modeling_prophetnet.py | 3 ++- .../test_modeling_switch_transformers.py | 1 + tests/models/t5/test_modeling_t5.py | 1 + tests/models/t5/test_modeling_tf_t5.py | 1 + tests/models/whisper/test_modeling_whisper.py | 6 +++++- 35 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index d7e5be017b..8d7d72711e 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -231,6 +231,7 @@ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( [ + ("blip", "TFBlipForConditionalGeneration"), ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"), ] ) diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index f2b1b1efda..02d046924c 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -40,6 +40,7 @@ from ...test_modeling_common import ( ids_tensor, random_attention_mask, ) +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): @@ -420,8 +421,9 @@ class AlignModelTester: @require_torch -class AlignModelTest(ModelTesterMixin, unittest.TestCase): +class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (AlignModel,) if is_torch_available() else () + pipeline_model_mapping = {"feature-extraction": AlignModel} if is_torch_available() else {} fx_compatible = False test_head_masking = False test_pruning = False diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index d5c2094c15..36837c9556 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -429,9 +429,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin "fill-mask": BartForConditionalGeneration, "question-answering": BartForQuestionAnswering, "summarization": BartForConditionalGeneration, - "text2text-generation": BartForConditionalGeneration, "text-classification": BartForSequenceClassification, "text-generation": BartForCausalLM, + "text2text-generation": BartForConditionalGeneration, + "translation": BartForConditionalGeneration, "zero-shot": BartForSequenceClassification, } if is_torch_available() diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index f827503d0b..0f0f8f9793 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -199,8 +199,9 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester "conversational": TFBartForConditionalGeneration, "feature-extraction": TFBartModel, "summarization": TFBartForConditionalGeneration, - "text2text-generation": TFBartForConditionalGeneration, "text-classification": TFBartForSequenceClassification, + "text2text-generation": TFBartForConditionalGeneration, + "translation": TFBartForConditionalGeneration, "zero-shot": TFBartForSequenceClassification, } if is_tf_available() diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index 88bc457c36..836cef014b 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -251,9 +251,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT "feature-extraction": BigBirdPegasusModel, "question-answering": BigBirdPegasusForQuestionAnswering, "summarization": BigBirdPegasusForConditionalGeneration, - "text2text-generation": BigBirdPegasusForConditionalGeneration, "text-classification": BigBirdPegasusForSequenceClassification, "text-generation": BigBirdPegasusForCausalLM, + "text2text-generation": BigBirdPegasusForConditionalGeneration, + "translation": BigBirdPegasusForConditionalGeneration, "zero-shot": BigBirdPegasusForSequenceClassification, } if is_torch_available() diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py index c4b85c8a46..8762d11aaa 100644 --- a/tests/models/blenderbot/test_modeling_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_blenderbot.py @@ -232,8 +232,9 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste "conversational": BlenderbotForConditionalGeneration, "feature-extraction": BlenderbotModel, "summarization": BlenderbotForConditionalGeneration, - "text2text-generation": BlenderbotForConditionalGeneration, "text-generation": BlenderbotForCausalLM, + "text2text-generation": BlenderbotForConditionalGeneration, + "translation": BlenderbotForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/blenderbot/test_modeling_tf_blenderbot.py b/tests/models/blenderbot/test_modeling_tf_blenderbot.py index e6b485346a..2db959e9f7 100644 --- a/tests/models/blenderbot/test_modeling_tf_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_tf_blenderbot.py @@ -185,6 +185,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te "feature-extraction": TFBlenderbotModel, "summarization": TFBlenderbotForConditionalGeneration, "text2text-generation": TFBlenderbotForConditionalGeneration, + "translation": TFBlenderbotForConditionalGeneration, } if is_tf_available() else {} diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index d2cfa94c19..f8d247ba12 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -226,8 +226,9 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline "conversational": BlenderbotSmallForConditionalGeneration, "feature-extraction": BlenderbotSmallModel, "summarization": BlenderbotSmallForConditionalGeneration, - "text2text-generation": BlenderbotSmallForConditionalGeneration, "text-generation": BlenderbotSmallForCausalLM, + "text2text-generation": BlenderbotSmallForConditionalGeneration, + "translation": BlenderbotSmallForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py index 057901f500..67a4f7ad7b 100644 --- a/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py @@ -187,6 +187,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte "feature-extraction": TFBlenderbotSmallModel, "summarization": TFBlenderbotSmallForConditionalGeneration, "text2text-generation": TFBlenderbotSmallForConditionalGeneration, + "translation": TFBlenderbotSmallForConditionalGeneration, } if is_tf_available() else {} diff --git a/tests/models/convnextv2/test_modeling_convnextv2.py b/tests/models/convnextv2/test_modeling_convnextv2.py index 008481ab31..85ce5b1813 100644 --- a/tests/models/convnextv2/test_modeling_convnextv2.py +++ b/tests/models/convnextv2/test_modeling_convnextv2.py @@ -26,6 +26,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): @@ -164,7 +165,7 @@ class ConvNextV2ModelTester: @require_torch -class ConvNextV2ModelTest(ModelTesterMixin, unittest.TestCase): +class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): """ Here we also overwrite some of the tests of test_modeling_common.py, as ConvNextV2 does not use input_ids, inputs_embeds, attention_mask and seq_length. @@ -179,6 +180,11 @@ class ConvNextV2ModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + pipeline_model_mapping = ( + {"feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification} + if is_torch_available() + else {} + ) fx_compatible = False test_pruning = False diff --git a/tests/models/efficientnet/test_modeling_efficientnet.py b/tests/models/efficientnet/test_modeling_efficientnet.py index 687cb98c53..85e136e8f2 100644 --- a/tests/models/efficientnet/test_modeling_efficientnet.py +++ b/tests/models/efficientnet/test_modeling_efficientnet.py @@ -24,6 +24,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_av from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): @@ -122,13 +123,18 @@ class EfficientNetModelTester: @require_torch -class EfficientNetModelTest(ModelTesterMixin, unittest.TestCase): +class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): """ Here we also overwrite some of the tests of test_modeling_common.py, as EfficientNet does not use input_ids, inputs_embeds, attention_mask and seq_length. """ all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else () + pipeline_model_mapping = ( + {"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification} + if is_torch_available() + else {} + ) fx_compatible = False test_pruning = False diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index 90c21e6543..7a39e82aea 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -163,6 +163,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin "feature-extraction": FSMTModel, "summarization": FSMTForConditionalGeneration, "text2text-generation": FSMTForConditionalGeneration, + "translation": FSMTForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py b/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py index d0c8a090ec..83d70ab18e 100644 --- a/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py +++ b/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py @@ -96,7 +96,7 @@ class GPTSanJapaneseTester: def get_config(self): return GPTSanJapaneseConfig( - vocab_size=36000, + vocab_size=self.vocab_size, num_contexts=self.seq_length, d_model=self.hidden_size, d_ff=self.d_ff, diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index 271f997bee..493846b670 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -26,6 +26,7 @@ from transformers.testing_utils import is_flaky, require_torch, slow, torch_devi from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin TOLERANCE = 1e-4 @@ -177,9 +178,10 @@ class InformerModelTester: @require_torch -class InformerModelTest(ModelTesterMixin, unittest.TestCase): +class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (InformerModel, InformerForPrediction) if is_torch_available() else () all_generative_model_classes = (InformerForPrediction,) if is_torch_available() else () + pipeline_model_mapping = {"feature-extraction": InformerModel} if is_torch_available() else {} is_encoder_decoder = True test_pruning = False test_head_masking = False diff --git a/tests/models/led/test_modeling_led.py b/tests/models/led/test_modeling_led.py index 31c78eacc6..b6dfc3256b 100644 --- a/tests/models/led/test_modeling_led.py +++ b/tests/models/led/test_modeling_led.py @@ -282,8 +282,9 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, "feature-extraction": LEDModel, "question-answering": LEDForQuestionAnswering, "summarization": LEDForConditionalGeneration, - "text2text-generation": LEDForConditionalGeneration, "text-classification": LEDForSequenceClassification, + "text2text-generation": LEDForConditionalGeneration, + "translation": LEDForConditionalGeneration, "zero-shot": LEDForSequenceClassification, } if is_torch_available() diff --git a/tests/models/led/test_modeling_tf_led.py b/tests/models/led/test_modeling_tf_led.py index cf7762ba22..7bac1ced83 100644 --- a/tests/models/led/test_modeling_tf_led.py +++ b/tests/models/led/test_modeling_tf_led.py @@ -199,6 +199,7 @@ class TFLEDModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase) "feature-extraction": TFLEDModel, "summarization": TFLEDForConditionalGeneration, "text2text-generation": TFLEDForConditionalGeneration, + "translation": TFLEDForConditionalGeneration, } if is_tf_available() else {} diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index 39a7e278ef..0f7ae0a272 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -510,6 +510,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix "feature-extraction": LongT5Model, "summarization": LongT5ForConditionalGeneration, "text2text-generation": LongT5ForConditionalGeneration, + "translation": LongT5ForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index e37a023307..d081041978 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -237,6 +237,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix "feature-extraction": M2M100Model, "summarization": M2M100ForConditionalGeneration, "text2text-generation": M2M100ForConditionalGeneration, + "translation": M2M100ForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 7b3eb1fb8b..fe5f860652 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -244,8 +244,9 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix "conversational": MarianMTModel, "feature-extraction": MarianModel, "summarization": MarianMTModel, - "text2text-generation": MarianMTModel, "text-generation": MarianForCausalLM, + "text2text-generation": MarianMTModel, + "translation": MarianMTModel, } if is_torch_available() else {} diff --git a/tests/models/marian/test_modeling_tf_marian.py b/tests/models/marian/test_modeling_tf_marian.py index 496e45e5c9..16b19b0f97 100644 --- a/tests/models/marian/test_modeling_tf_marian.py +++ b/tests/models/marian/test_modeling_tf_marian.py @@ -187,6 +187,7 @@ class TFMarianModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa "feature-extraction": TFMarianModel, "summarization": TFMarianMTModel, "text2text-generation": TFMarianMTModel, + "translation": TFMarianMTModel, } if is_tf_available() else {} diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 607babb1ba..b28d539b78 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -239,9 +239,10 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi "fill-mask": MBartForConditionalGeneration, "question-answering": MBartForQuestionAnswering, "summarization": MBartForConditionalGeneration, - "text2text-generation": MBartForConditionalGeneration, "text-classification": MBartForSequenceClassification, "text-generation": MBartForCausalLM, + "text2text-generation": MBartForConditionalGeneration, + "translation": MBartForConditionalGeneration, "zero-shot": MBartForSequenceClassification, } if is_torch_available() diff --git a/tests/models/mbart/test_modeling_tf_mbart.py b/tests/models/mbart/test_modeling_tf_mbart.py index c3e5721473..b143fc6877 100644 --- a/tests/models/mbart/test_modeling_tf_mbart.py +++ b/tests/models/mbart/test_modeling_tf_mbart.py @@ -190,6 +190,7 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas "feature-extraction": TFMBartModel, "summarization": TFMBartForConditionalGeneration, "text2text-generation": TFMBartForConditionalGeneration, + "translation": TFMBartForConditionalGeneration, } if is_tf_available() else {} diff --git a/tests/models/mega/test_modeling_mega.py b/tests/models/mega/test_modeling_mega.py index 7ea0efb83a..6b1ebce137 100644 --- a/tests/models/mega/test_modeling_mega.py +++ b/tests/models/mega/test_modeling_mega.py @@ -471,9 +471,11 @@ class MegaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin pipeline_model_mapping = ( { "feature-extraction": MegaModel, + "fill-mask": MegaForMaskedLM, "question-answering": MegaForQuestionAnswering, "text-classification": MegaForSequenceClassification, "text-generation": MegaForCausalLM, + "token-classification": MegaForTokenClassification, "zero-shot": MegaForSequenceClassification, } if is_torch_available() diff --git a/tests/models/mgp_str/test_modeling_mgp_str.py b/tests/models/mgp_str/test_modeling_mgp_str.py index 990f78587d..1d972e22a3 100644 --- a/tests/models/mgp_str/test_modeling_mgp_str.py +++ b/tests/models/mgp_str/test_modeling_mgp_str.py @@ -25,6 +25,7 @@ from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): @@ -116,8 +117,9 @@ class MgpstrModelTester: @require_torch -class MgpstrModelTest(ModelTesterMixin, unittest.TestCase): +class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else () + pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {} fx_compatible = False test_pruning = False diff --git a/tests/models/mvp/test_modeling_mvp.py b/tests/models/mvp/test_modeling_mvp.py index 7ade831659..e996a998a8 100644 --- a/tests/models/mvp/test_modeling_mvp.py +++ b/tests/models/mvp/test_modeling_mvp.py @@ -420,9 +420,10 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, "fill-mask": MvpForConditionalGeneration, "question-answering": MvpForQuestionAnswering, "summarization": MvpForConditionalGeneration, - "text2text-generation": MvpForConditionalGeneration, "text-classification": MvpForSequenceClassification, "text-generation": MvpForCausalLM, + "text2text-generation": MvpForConditionalGeneration, + "translation": MvpForConditionalGeneration, "zero-shot": MvpForSequenceClassification, } if is_torch_available() diff --git a/tests/models/nllb_moe/test_modeling_nllb_moe.py b/tests/models/nllb_moe/test_modeling_nllb_moe.py index 9f072a06d2..69a3d16ad9 100644 --- a/tests/models/nllb_moe/test_modeling_nllb_moe.py +++ b/tests/models/nllb_moe/test_modeling_nllb_moe.py @@ -255,6 +255,7 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi "feature-extraction": NllbMoeModel, "summarization": NllbMoeForConditionalGeneration, "text2text-generation": NllbMoeForConditionalGeneration, + "translation": NllbMoeForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index de6c78df61..cb8b36c9af 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -242,8 +242,9 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi "conversational": PegasusForConditionalGeneration, "feature-extraction": PegasusModel, "summarization": PegasusForConditionalGeneration, - "text2text-generation": PegasusForConditionalGeneration, "text-generation": PegasusForCausalLM, + "text2text-generation": PegasusForConditionalGeneration, + "translation": PegasusForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/pegasus/test_modeling_tf_pegasus.py b/tests/models/pegasus/test_modeling_tf_pegasus.py index a7f47f6756..6816cc34ef 100644 --- a/tests/models/pegasus/test_modeling_tf_pegasus.py +++ b/tests/models/pegasus/test_modeling_tf_pegasus.py @@ -185,6 +185,7 @@ class TFPegasusModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC "feature-extraction": TFPegasusModel, "summarization": TFPegasusForConditionalGeneration, "text2text-generation": TFPegasusForConditionalGeneration, + "translation": TFPegasusForConditionalGeneration, } if is_tf_available() else {} diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index 545fe5149b..73c4ee62bf 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -204,6 +204,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM "feature-extraction": PegasusXModel, "summarization": PegasusXForConditionalGeneration, "text2text-generation": PegasusXForConditionalGeneration, + "translation": PegasusXForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 4c2b2b84ad..6ad226747a 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -224,9 +224,10 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix "conversational": PLBartForConditionalGeneration, "feature-extraction": PLBartModel, "summarization": PLBartForConditionalGeneration, - "text2text-generation": PLBartForConditionalGeneration, "text-classification": PLBartForSequenceClassification, "text-generation": PLBartForCausalLM, + "text2text-generation": PLBartForConditionalGeneration, + "translation": PLBartForConditionalGeneration, "zero-shot": PLBartForSequenceClassification, } if is_torch_available() diff --git a/tests/models/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py index 1d4b45e23c..baf7351bf5 100644 --- a/tests/models/prophetnet/test_modeling_prophetnet.py +++ b/tests/models/prophetnet/test_modeling_prophetnet.py @@ -894,8 +894,9 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste "conversational": ProphetNetForConditionalGeneration, "feature-extraction": ProphetNetModel, "summarization": ProphetNetForConditionalGeneration, - "text2text-generation": ProphetNetForConditionalGeneration, "text-generation": ProphetNetForCausalLM, + "text2text-generation": ProphetNetForConditionalGeneration, + "translation": ProphetNetForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index a246f46487..f8730d8993 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -558,6 +558,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel "feature-extraction": SwitchTransformersModel, "summarization": SwitchTransformersForConditionalGeneration, "text2text-generation": SwitchTransformersForConditionalGeneration, + "translation": SwitchTransformersForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index cfe1460895..3bc4d03a9c 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -528,6 +528,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, "feature-extraction": T5Model, "summarization": T5ForConditionalGeneration, "text2text-generation": T5ForConditionalGeneration, + "translation": T5ForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py index 8a9112ff41..a1d784ae2f 100644 --- a/tests/models/t5/test_modeling_tf_t5.py +++ b/tests/models/t5/test_modeling_tf_t5.py @@ -250,6 +250,7 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): "feature-extraction": TFT5Model, "summarization": TFT5ForConditionalGeneration, "text2text-generation": TFT5ForConditionalGeneration, + "translation": TFT5ForConditionalGeneration, } if is_tf_available() else {} diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f0ba1a00f5..34531a1e12 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -277,7 +277,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = ( - {"automatic-speech-recognition": WhisperForConditionalGeneration, "feature-extraction": WhisperModel} + { + "audio-classification": WhisperForAudioClassification, + "automatic-speech-recognition": WhisperForConditionalGeneration, + "feature-extraction": WhisperModel, + } if is_torch_available() else {} )