Add t5 to pipeline(task='summarization') (#3413)
* solve conflicts * move warnings below * incorporate changes * add pad_to_max_length to pipelines * add bug fix for T5 beam search * add prefix patterns * make style * fix conflicts * adapt pipelines for task specific parameters * improve docstring * remove unused patterns
This commit is contained in:
committed by
GitHub
parent
ffcffebe85
commit
9c683ef01e
@@ -78,6 +78,9 @@ TF_FILL_MASK_FINETUNED_MODELS = [
|
||||
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
|
||||
]
|
||||
|
||||
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
|
||||
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
|
||||
|
||||
|
||||
class MonoColumnInputTestCase(unittest.TestCase):
|
||||
def _test_mono_column_pipeline(
|
||||
@@ -252,10 +255,22 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["summary_text"]
|
||||
nlp = pipeline(task="summarization")
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
for model, tokenizer in SUMMARIZATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_tf_summarization(self):
|
||||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["summary_text"]
|
||||
for model, tokenizer in TF_SUMMARIZATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer, framework="tf")
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
|
||||
class MultiColumnInputTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user