[pipelines] Text2TextGenerationPipeline (#6744)

* add Text2TextGenerationPipeline

* remove max length warning

* remove comments

* remove input_length

* fix typo

* add tests

* use TFAutoModelForSeq2SeqLM

* doc

* typo

* add the doc below TextGenerationPipeline

* doc nit

* style

* delete comment
This commit is contained in:
Suraj Patil
2020-09-02 17:04:35 +05:30
committed by GitHub
parent 6b24281229
commit 4230d30f77
4 changed files with 139 additions and 1 deletions

View File

@@ -28,6 +28,9 @@ TRANSLATION_FINETUNED_MODELS = [
]
TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translation_en_to_fr")]
TEXT2TEXT_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
TF_TEXT2TEXT_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
DIALOGUE_FINETUNED_MODELS = ["microsoft/DialoGPT-medium"]
expected_fill_mask_result = [
@@ -394,6 +397,28 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp = pipeline(task=task, model=model, tokenizer=model, framework="tf")
self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys, invalid_inputs=invalid_inputs)
@require_torch
def test_torch_text2text(self):
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["generated_text"]
for model_name in TEXT2TEXT_FINETUNED_MODELS:
nlp = pipeline(task="text2text-generation", model=model_name, tokenizer=model_name)
self._test_mono_column_pipeline(
nlp,
VALID_INPUTS,
mandatory_keys,
invalid_inputs,
)
@require_tf
@slow
def test_tf_text2text(self):
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["generated_text"]
for model in TEXT2TEXT_FINETUNED_MODELS:
nlp = pipeline(task="text2text-generation", model=model, tokenizer=model, framework="tf")
self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys, invalid_inputs=invalid_inputs)
@require_torch
def test_torch_text_generation(self):
for model_name in TEXT_GENERATION_FINETUNED_MODELS: