diff --git a/docs/source/main_classes/pipelines.rst b/docs/source/main_classes/pipelines.rst
index 85bf673d85..0ef9858286 100644
--- a/docs/source/main_classes/pipelines.rst
+++ b/docs/source/main_classes/pipelines.rst
@@ -66,3 +66,9 @@ SummarizationPipeline
 ==========================================
 
 .. autoclass:: transformers.SummarizationPipeline
+
+
+TextGenerationPipeline
+==========================================
+
+.. autoclass:: transformers.TextGenerationPipeline
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 88c48020d4..0a2b93f4c8 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -117,6 +117,7 @@ from .pipelines import (
     QuestionAnsweringPipeline,
     SummarizationPipeline,
     TextClassificationPipeline,
+    TextGenerationPipeline,
     TokenClassificationPipeline,
     TranslationPipeline,
     pipeline,
diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index 761930c367..6e4f7c5b2b 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -520,6 +520,118 @@ class FeatureExtractionPipeline(Pipeline):
         return super().__call__(*args, **kwargs).tolist()
 
 
+class TextGenerationPipeline(Pipeline):
+    """
+    Language generation pipeline using any ModelWithLMHead head. This pipeline predicts the words that will follow a specified text prompt.
+
+    This language generation pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
+    the following task identifier(s):
+
+    - "text-generation", for generating text from a specified prompt.
+
+    The models that this pipeline can use are models that have been trained with an autoregressive language modeling objective,
+    which includes the uni-directional models in the library (e.g. gpt2).
+    See the list of available community models on
+    `huggingface.co/models <https://huggingface.co/models>`__.
+    """
+
+    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
+    # in https://github.com/rusiaaman/XLNet-gen#methodology
+    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
+    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
+    (except for Alexei and Maria) are discovered.
+    The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
+    remainder of the story. 1883 Western Siberia,
+    a young Grigori Rasputin is asked by his father and a group of men to perform magic.
+    Rasputin has a vision and denounces one of the men as a horse thief. Although his
+    father initially slaps him for making such an accusation, Rasputin watches as the
+    man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
+    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
+    with people, even a bishop, begging for his blessing. """
+
+    def __call__(
+        self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+    ):
+        """
+        Generate a continuation for every prompt passed in ``texts``.
+
+        Args:
+            texts: one or several prompts, as accepted by ``self._args_parser``.
+            return_tensors: when true, each record holds the generated token ids under ``generated_token_ids``.
+            return_text: when true, each record holds the decoded text under ``generated_text``.
+            clean_up_tokenization_spaces: forwarded to ``self.tokenizer.decode``.
+            generate_kwargs: forwarded unchanged to ``self.model.generate``.
+
+        Returns:
+            One list of records per prompt (a record per generated sequence); when there is exactly one prompt,
+            that single list is returned directly instead of a list of lists.
+        """
+        text_inputs = self._args_parser(*texts)
+
+        results = []
+        for prompt_text in text_inputs:
+            # Manage correct placement of the tensors
+            with self.device_placement():
+                if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
+                    inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text)
+                else:
+                    inputs = self._parse_and_tokenize(prompt_text)
+
+                if self.framework == "pt":
+                    inputs = self.ensure_tensor_on_device(**inputs)
+
+                input_ids = inputs["input_ids"]
+
+                # Ensure that batch size = 1 (batch generation not allowed for now)
+                assert (
+                    input_ids.shape[0] == 1
+                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
+
+                output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
+
+            if return_text:
+                # Length of the decoded model input (padding text included for XLNet / Transfo-XL). It is the
+                # same for every generated sequence, so it is computed once per prompt.
+                decoded_input_length = len(
+                    self.tokenizer.decode(
+                        input_ids[0],
+                        skip_special_tokens=True,
+                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                    )
+                )
+
+            result = []
+            for generated_sequence in output_sequences:
+                # torch tensors expose ``tolist`` directly; TensorFlow eager tensors do not and go through numpy
+                if self.framework == "pt":
+                    generated_sequence = generated_sequence.tolist()
+                else:
+                    generated_sequence = generated_sequence.numpy().tolist()
+
+                record = {}
+                if return_tensors:
+                    record["generated_token_ids"] = generated_sequence
+                if return_text:
+                    # Decode text
+                    text = self.tokenizer.decode(
+                        generated_sequence,
+                        skip_special_tokens=True,
+                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                    )
+
+                    # Swap the decoded model input for the raw prompt: this drops PADDING_TEXT when XLNet or
+                    # Transfo-XL is used
+                    record["generated_text"] = prompt_text + text[decoded_input_length:]
+
+                result.append(record)
+            results.append(result)
+
+        if len(results) == 1:
+            return results[0]
+
+        return results
+
+
 class TextClassificationPipeline(Pipeline):
     """
     Text classification pipeline using ModelForSequenceClassification head. See the
@@ -1456,6 +1568,12 @@ SUPPORTED_TASKS = {
             "tokenizer": ("t5-base", {"use_fast": False}),
         },
     },
+    "text-generation": {
+        "impl": TextGenerationPipeline,
+        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
+        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "default": {"model": {"pt": "gpt2", "tf": "gpt2"}, "config": None, "tokenizer": "gpt2"},
+    },
 }
 
 
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 1b56ef637d..0f91b813d7 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -60,6 +60,11 @@ TEXT_CLASSIF_FINETUNED_MODELS = {
     )
 }
 
+TEXT_GENERATION_FINETUNED_MODELS = {
+    ("gpt2", "gpt2"),
+    ("xlnet-base-cased", "xlnet-base-cased"),
+}
+
 FILL_MASK_FINETUNED_MODELS = [
     (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
 ]
@@ -293,6 +298,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
             nlp, valid_inputs, invalid_inputs, mandatory_keys,
         )
 
+    @require_torch
+    def test_text_generation(self):
+        valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
+        invalid_inputs = [None]
+        for model, tokenizer in TEXT_GENERATION_FINETUNED_MODELS:
+            nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="pt")
+            self._test_mono_column_pipeline(
+                nlp, valid_inputs, invalid_inputs, {},
+            )
+
 
 class MultiColumnInputTestCase(unittest.TestCase):
     def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
@@ -371,6 +386,7 @@ class PipelineCommonTests(unittest.TestCase):
         "translation_en_to_fr",
         "translation_en_to_de",
         "translation_en_to_ro",
+        "text-generation",
     )
 
     @slow