Update summarization run_pipeline_test (#20623)
* update summarization run_pipeline_test * update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -17,11 +17,7 @@ import unittest
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
LEDConfig,
|
|
||||||
LongT5Config,
|
|
||||||
SummarizationPipeline,
|
SummarizationPipeline,
|
||||||
SwitchTransformersConfig,
|
|
||||||
T5Config,
|
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
|
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
|
||||||
@@ -55,7 +51,17 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
|
|||||||
)
|
)
|
||||||
self.assertEqual(outputs, [{"summary_text": ANY(str)}])
|
self.assertEqual(outputs, [{"summary_text": ANY(str)}])
|
||||||
|
|
||||||
if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)):
|
model_can_handle_longer_seq = [
|
||||||
|
"SwitchTransformersConfig",
|
||||||
|
"T5Config",
|
||||||
|
"LongT5Config",
|
||||||
|
"LEDConfig",
|
||||||
|
"PegasusXConfig",
|
||||||
|
"FSMTConfig",
|
||||||
|
"M2M100Config",
|
||||||
|
"ProphetNetConfig", # positional embeddings up to a fixed maximum size (otherwise clamping the values)
|
||||||
|
]
|
||||||
|
if model.config.__class__.__name__ not in model_can_handle_longer_seq:
|
||||||
# Switch Transformers, LED, T5, LongT5 can handle it.
|
# Switch Transformers, LED, T5, LongT5 can handle it.
|
||||||
# Too long.
|
# Too long.
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
|
|||||||
Reference in New Issue
Block a user