Even more validation. (#20762)

* Even more validation.

* Fixing order.
This commit is contained in:
Nicolas Patry
2022-12-15 10:05:54 +01:00
committed by GitHub
parent 67acb07e9e
commit a9912d2fca
2 changed files with 8 additions and 0 deletions

View File

@@ -132,8 +132,12 @@ class TextGenerationPipeline(Pipeline):
if return_full_text is not None and return_type is None: if return_full_text is not None and return_type is None:
if return_text is not None: if return_text is not None:
raise ValueError("`return_text` is mutually exclusive with `return_full_text`") raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
if return_tensors is not None:
raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
if return_tensors is not None and return_type is None: if return_tensors is not None and return_type is None:
if return_text is not None:
raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
return_type = ReturnType.TENSORS return_type = ReturnType.TENSORS
if return_type is not None: if return_type is not None:
postprocess_params["return_type"] = return_type postprocess_params["return_type"] = return_type

View File

@@ -203,6 +203,10 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
outputs = text_generator("test", return_full_text=True, return_text=True) outputs = text_generator("test", return_full_text=True, return_text=True)
with self.assertRaises(ValueError):
outputs = text_generator("test", return_full_text=True, return_tensors=True)
with self.assertRaises(ValueError):
outputs = text_generator("test", return_text=True, return_tensors=True)
# Empty prompt is slighly special # Empty prompt is slighly special
# it requires BOS token to exist. # it requires BOS token to exist.