[Pipeline, Generation] tf generation pipeline bug (#4217)
* fix PR * move tests to correct place
This commit is contained in:
committed by
GitHub
parent
8bf7312654
commit
cf08830c28
@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
|
||||
("xlnet-base-cased", "xlnet-base-cased"),
|
||||
}
|
||||
|
||||
TF_TEXT_GENERATION_FINETUNED_MODELS = {
|
||||
("gpt2", "gpt2"),
|
||||
("xlnet-base-cased", "xlnet-base-cased"),
|
||||
}
|
||||
|
||||
FILL_MASK_FINETUNED_MODELS = [
|
||||
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
|
||||
]
|
||||
@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
nlp, valid_inputs, invalid_inputs, {},
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_tf_text_generation(self):
|
||||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [None]
|
||||
for model, tokenizer in TF_TEXT_GENERATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="tf")
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, {},
|
||||
)
|
||||
|
||||
|
||||
class MultiColumnInputTestCase(unittest.TestCase):
|
||||
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
||||
|
||||
Reference in New Issue
Block a user