[Generation] Generation should allow starting with an empty prompt (#3993)

* fix empty prompt

* fix length in generation pipeline
This commit is contained in:
Patrick von Platen
2020-04-28 14:33:15 +02:00
committed by GitHub
parent 52679fbc2e
commit 180585741c
2 changed files with 25 additions and 15 deletions

View File

@@ -563,14 +563,19 @@ class TextGenerationPipeline(Pipeline):
else:
inputs = self._parse_and_tokenize(prompt_text)
if self.framework == "pt":
# set input_ids to None to allow empty prompt
if inputs["input_ids"].shape[-1] == 0:
inputs["input_ids"] = None
inputs["attention_mask"] = None
if self.framework == "pt" and inputs["input_ids"] is not None:
inputs = self.ensure_tensor_on_device(**inputs)
input_ids = inputs["input_ids"]
# Ensure that batch size = 1 (batch generation not allowed for now)
assert (
input_ids.shape[0] == 1
input_ids is None or input_ids.shape[0] == 1
), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
@@ -590,18 +595,18 @@ class TextGenerationPipeline(Pipeline):
)
# Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
record["generated_text"] = (
prompt_text
+ text[
len(
self.tokenizer.decode(
input_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
) :
]
)
if input_ids is None:
prompt_length = 0
else:
prompt_length = len(
self.tokenizer.decode(
input_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
)
record["generated_text"] = prompt_text + text[prompt_length:]
result.append(record)
results += [result]