[Pipeline, Generation] tf generation pipeline bug (#4217)

* fix PR

* move tests to correct place
This commit is contained in:
Patrick von Platen
2020-05-08 14:30:05 +02:00
committed by GitHub
parent 8bf7312654
commit cf08830c28
2 changed files with 38 additions and 1 deletions

View File

@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
("xlnet-base-cased", "xlnet-base-cased"),
}
TF_TEXT_GENERATION_FINETUNED_MODELS = {
("gpt2", "gpt2"),
("xlnet-base-cased", "xlnet-base-cased"),
}
FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp, valid_inputs, invalid_inputs, {},
)
@require_tf
def test_tf_text_generation(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [None]
for model, tokenizer in TF_TEXT_GENERATION_FINETUNED_MODELS:
nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, {},
)
class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):