[pipelines] Text2TextGenerationPipeline (#6744)
* add Text2TextGenerationPipeline * remove max length warning * remove comments * remove input_length * fix typo * add tests * use TFAutoModelForSeq2SeqLM * doc * typo * add the doc below TextGenerationPipeline * doc nit * style * delete comment
This commit is contained in:
@@ -28,6 +28,9 @@ TRANSLATION_FINETUNED_MODELS = [
|
||||
]
|
||||
TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translation_en_to_fr")]
|
||||
|
||||
TEXT2TEXT_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
|
||||
TF_TEXT2TEXT_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
|
||||
|
||||
DIALOGUE_FINETUNED_MODELS = ["microsoft/DialoGPT-medium"]
|
||||
|
||||
expected_fill_mask_result = [
|
||||
@@ -394,6 +397,28 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
nlp = pipeline(task=task, model=model, tokenizer=model, framework="tf")
|
||||
self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys, invalid_inputs=invalid_inputs)
|
||||
|
||||
@require_torch
|
||||
def test_torch_text2text(self):
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["generated_text"]
|
||||
for model_name in TEXT2TEXT_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="text2text-generation", model=model_name, tokenizer=model_name)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp,
|
||||
VALID_INPUTS,
|
||||
mandatory_keys,
|
||||
invalid_inputs,
|
||||
)
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_tf_text2text(self):
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["generated_text"]
|
||||
for model in TEXT2TEXT_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="text2text-generation", model=model, tokenizer=model, framework="tf")
|
||||
self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys, invalid_inputs=invalid_inputs)
|
||||
|
||||
@require_torch
|
||||
def test_torch_text_generation(self):
|
||||
for model_name in TEXT_GENERATION_FINETUNED_MODELS:
|
||||
|
||||
Reference in New Issue
Block a user