From cf08830c2868c5335b1657f95779f370862e3a28 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 8 May 2020 14:30:05 +0200 Subject: [PATCH] [Pipeline, Generation] tf generation pipeline bug (#4217) * fix PR * move tests to correct place --- src/transformers/pipelines.py | 24 +++++++++++++++++++++++- tests/test_pipelines.py | 15 +++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 58907a38ab..bc28e84787 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline): # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # in https://github.com/rusiaaman/XLNet-gen#methodology # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e + PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the @@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline): the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing. """ + ALLOWED_MODELS = [ + "XLNetLMHeadModel", + "TransfoXLLMHeadModel", + "ReformerModelWithLMHead", + "GPT2LMHeadModel", + "OpenAIGPTLMHeadModel", + "CTRLLMHeadModel", + "TFXLNetLMHeadModel", + "TFTransfoXLLMHeadModel", + "TFGPT2LMHeadModel", + "TFOpenAIGPTLMHeadModel", + "TFCTRLLMHeadModel", + ] + def __call__( self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs ): + if self.model.__class__.__name__ not in self.ALLOWED_MODELS: + raise NotImplementedError( + "Generation is currently not supported for {}. Please select a model from {} for generation.".format( + self.model.__class__.__name__, self.ALLOWED_MODELS + ) + ) + text_inputs = self._args_parser(*args) results = [] @@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline): result = [] for generated_sequence in output_sequences: - generated_sequence = generated_sequence.tolist() + generated_sequence = generated_sequence.numpy().tolist() record = {} if return_tensors: record["generated_token_ids"] = generated_sequence diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index d0bac672e9..3c8baf2d00 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = { ("xlnet-base-cased", "xlnet-base-cased"), } +TF_TEXT_GENERATION_FINETUNED_MODELS = { + ("gpt2", "gpt2"), + ("xlnet-base-cased", "xlnet-base-cased"), +} + FILL_MASK_FINETUNED_MODELS = [ (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None), ] @@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase): nlp, valid_inputs, invalid_inputs, {}, ) + @require_tf + def test_tf_text_generation(self): + valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]] + invalid_inputs = [None] + for model, tokenizer in TF_TEXT_GENERATION_FINETUNED_MODELS: + nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="tf") + self._test_mono_column_pipeline( + nlp, valid_inputs, invalid_inputs, {}, + ) + class MultiColumnInputTestCase(unittest.TestCase): def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):