From 180585741cf3cdd6890cb99610923a8ae9691220 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Apr 2020 14:33:15 +0200 Subject: [PATCH] [Generation] Generation should allow to start with empty prompt (#3993) * fix empty prompt * fix length in generation pipeline --- examples/run_generation.py | 7 ++++++- src/transformers/pipelines.py | 33 +++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 3f90ee5833..b4b91542e9 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -221,8 +221,13 @@ def main(): encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") encoded_prompt = encoded_prompt.to(args.device) + if encoded_prompt.size()[-1] == 0: + input_ids = None + else: + input_ids = encoded_prompt + output_sequences = model.generate( - input_ids=encoded_prompt, + input_ids=input_ids, max_length=args.length + len(encoded_prompt[0]), temperature=args.temperature, top_k=args.k, diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 6e4f7c5b2b..5871890b9f 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -563,14 +563,19 @@ class TextGenerationPipeline(Pipeline): else: inputs = self._parse_and_tokenize(prompt_text) - if self.framework == "pt": + # set input_ids to None to allow empty prompt + if inputs["input_ids"].shape[-1] == 0: + inputs["input_ids"] = None + inputs["attention_mask"] = None + + if self.framework == "pt" and inputs["input_ids"] is not None: inputs = self.ensure_tensor_on_device(**inputs) input_ids = inputs["input_ids"] # Ensure that batch size = 1 (batch generation not allowed for now) assert ( - input_ids.shape[0] == 1 + input_ids is None or input_ids.shape[0] == 1 ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information." output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL @@ -590,18 +595,18 @@ class TextGenerationPipeline(Pipeline): ) # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used - record["generated_text"] = ( - prompt_text - + text[ - len( - self.tokenizer.decode( - input_ids[0], - skip_special_tokens=True, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - ) - ) : - ] - ) + if input_ids is None: + prompt_length = 0 + else: + prompt_length = len( + self.tokenizer.decode( + input_ids[0], + skip_special_tokens=True, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + ) + ) + + record["generated_text"] = prompt_text + text[prompt_length:] result.append(record) results += [result]