From a12c5cbcd89d3eb11df594c2a0874324ed397f62 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 13 Dec 2022 13:42:36 +0100 Subject: [PATCH] Change a logic in pipeline test regarding TF (#20710) * Fix the pipeline test regarding TF * Fix the pipeline test regarding TF * update comment Co-authored-by: ydshieh --- .../pipelines/test_pipelines_summarization.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py index eb688d69c7..781716b5ba 100644 --- a/tests/pipelines/test_pipelines_summarization.py +++ b/tests/pipelines/test_pipelines_summarization.py @@ -18,9 +18,10 @@ from transformers import ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, SummarizationPipeline, + TFPreTrainedModel, pipeline, ) -from transformers.testing_utils import require_tf, require_torch, slow, torch_device +from transformers.testing_utils import get_gpu_count, require_tf, require_torch, slow, torch_device from transformers.tokenization_utils import TruncationStrategy from .test_pipelines_common import ANY, PipelineTestCaseMeta @@ -51,6 +52,7 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe ) self.assertEqual(outputs, [{"summary_text": ANY(str)}]) + # Some models (Switch Transformers, LED, T5, LongT5, etc) can handle long sequences. model_can_handle_longer_seq = [ "SwitchTransformersConfig", "T5Config", @@ -62,10 +64,16 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe "ProphetNetConfig", # positional embeddings up to a fixed maximum size (otherwise clamping the values) ] if model.config.__class__.__name__ not in model_can_handle_longer_seq: - # Switch Transformers, LED, T5, LongT5 can handle it. - # Too long. - with self.assertRaises(Exception): - outputs = summarizer("This " * 1000) + # Too long and exception is expected. + # For TF models, if the weights are initialized in GPU context, we won't get expected index error from + # the embedding layer. + if not ( + isinstance(model, TFPreTrainedModel) + and get_gpu_count() > 0 + and len(summarizer.model.trainable_weights) > 0 + ): + with self.assertRaises(Exception): + outputs = summarizer("This " * 1000) outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST) @require_torch