From cec5f7abd1062df18c32109b5c1d19a9bcc14174 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 7 Dec 2022 15:46:12 +0100 Subject: [PATCH] Update summarization `run_pipeline_test` (#20623) * update summarization run_pipeline_test * update Co-authored-by: ydshieh --- tests/pipelines/test_pipelines_summarization.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py index c4c646cee9..eb688d69c7 100644 --- a/tests/pipelines/test_pipelines_summarization.py +++ b/tests/pipelines/test_pipelines_summarization.py @@ -17,11 +17,7 @@ import unittest from transformers import ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - LEDConfig, - LongT5Config, SummarizationPipeline, - SwitchTransformersConfig, - T5Config, pipeline, ) from transformers.testing_utils import require_tf, require_torch, slow, torch_device @@ -55,7 +51,17 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe ) self.assertEqual(outputs, [{"summary_text": ANY(str)}]) - if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)): + model_can_handle_longer_seq = [ + "SwitchTransformersConfig", + "T5Config", + "LongT5Config", + "LEDConfig", + "PegasusXConfig", + "FSMTConfig", + "M2M100Config", + "ProphetNetConfig", # positional embeddings up to a fixed maximum size (otherwise clamping the values) + ] + if model.config.__class__.__name__ not in model_can_handle_longer_seq: # Switch Transformers, LED, T5, LongT5 can handle it. # Too long. with self.assertRaises(Exception):