From a9912d2fca8885b569a772f29ef9c1f68b6e9089 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Dec 2022 10:05:54 +0100 Subject: [PATCH] Even more validation. (#20762) * Even more validation. * Fixing order. --- src/transformers/pipelines/text_generation.py | 4 ++++ tests/pipelines/test_pipelines_text_generation.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 58c0edc3f0..b19d58f4ff 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -132,8 +132,12 @@ class TextGenerationPipeline(Pipeline): if return_full_text is not None and return_type is None: if return_text is not None: raise ValueError("`return_text` is mutually exclusive with `return_full_text`") + if return_tensors is not None: + raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`") return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT if return_tensors is not None and return_type is None: + if return_text is not None: + raise ValueError("`return_text` is mutually exclusive with `return_tensors`") return_type = ReturnType.TENSORS if return_type is not None: postprocess_params["return_type"] = return_type diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 4796241109..92bda4f810 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -203,6 +203,10 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM with self.assertRaises(ValueError): outputs = text_generator("test", return_full_text=True, return_text=True) + with self.assertRaises(ValueError): + outputs = text_generator("test", return_full_text=True, return_tensors=True) + with self.assertRaises(ValueError): + outputs = text_generator("test", return_text=True, return_tensors=True) # Empty prompt is slighly special # it requires BOS token to exist.