[Generation] Generation should allow to start with empty prompt (#3993)
* fix empty prompt * fix length in generation pipeline
This commit is contained in:
committed by
GitHub
parent
52679fbc2e
commit
180585741c
@@ -563,14 +563,19 @@ class TextGenerationPipeline(Pipeline):
|
||||
else:
|
||||
inputs = self._parse_and_tokenize(prompt_text)
|
||||
|
||||
if self.framework == "pt":
|
||||
# set input_ids to None to allow empty prompt
|
||||
if inputs["input_ids"].shape[-1] == 0:
|
||||
inputs["input_ids"] = None
|
||||
inputs["attention_mask"] = None
|
||||
|
||||
if self.framework == "pt" and inputs["input_ids"] is not None:
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
|
||||
# Ensure that batch size = 1 (batch generation not allowed for now)
|
||||
assert (
|
||||
input_ids.shape[0] == 1
|
||||
input_ids is None or input_ids.shape[0] == 1
|
||||
), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
|
||||
|
||||
output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
|
||||
@@ -590,18 +595,18 @@ class TextGenerationPipeline(Pipeline):
|
||||
)
|
||||
|
||||
# Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
|
||||
record["generated_text"] = (
|
||||
prompt_text
|
||||
+ text[
|
||||
len(
|
||||
self.tokenizer.decode(
|
||||
input_ids[0],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
) :
|
||||
]
|
||||
)
|
||||
if input_ids is None:
|
||||
prompt_length = 0
|
||||
else:
|
||||
prompt_length = len(
|
||||
self.tokenizer.decode(
|
||||
input_ids[0],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
)
|
||||
|
||||
record["generated_text"] = prompt_text + text[prompt_length:]
|
||||
|
||||
result.append(record)
|
||||
results += [result]
|
||||
|
||||
Reference in New Issue
Block a user