Add t5 to pipeline(task='summarization') (#3413)

* solve conflicts

* move warnings below

* incorporate changes

* add pad_to_max_length to pipelines

* add bug fix for T5 beam search

* add prefix patterns

* make style

* fix conflicts

* adapt pipelines for task specific parameters

* improve docstring

* remove unused patterns
This commit is contained in:
Patrick von Platen
2020-03-26 11:03:13 +01:00
committed by GitHub
parent ffcffebe85
commit 9c683ef01e
4 changed files with 120 additions and 42 deletions

View File

@@ -78,6 +78,9 @@ TF_FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
class MonoColumnInputTestCase(unittest.TestCase):
def _test_mono_column_pipeline(
@@ -252,10 +255,22 @@ class MonoColumnInputTestCase(unittest.TestCase):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
nlp = pipeline(task="summarization")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
for model, tokenizer in SUMMARIZATION_FINETUNED_MODELS:
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
@require_tf
def test_tf_summarization(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
for model, tokenizer in TF_SUMMARIZATION_FINETUNED_MODELS:
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
class MultiColumnInputTestCase(unittest.TestCase):