[Generation] Generation should allow to start with empty prompt (#3993)
* fix empty prompt * fix length in generation pipeline
This commit is contained in:
committed by
GitHub
parent
52679fbc2e
commit
180585741c
@@ -221,8 +221,13 @@ def main():
|
|||||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||||
encoded_prompt = encoded_prompt.to(args.device)
|
encoded_prompt = encoded_prompt.to(args.device)
|
||||||
|
|
||||||
|
if encoded_prompt.size()[-1] == 0:
|
||||||
|
input_ids = None
|
||||||
|
else:
|
||||||
|
input_ids = encoded_prompt
|
||||||
|
|
||||||
output_sequences = model.generate(
|
output_sequences = model.generate(
|
||||||
input_ids=encoded_prompt,
|
input_ids=input_ids,
|
||||||
max_length=args.length + len(encoded_prompt[0]),
|
max_length=args.length + len(encoded_prompt[0]),
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
top_k=args.k,
|
top_k=args.k,
|
||||||
|
|||||||
@@ -563,14 +563,19 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
inputs = self._parse_and_tokenize(prompt_text)
|
inputs = self._parse_and_tokenize(prompt_text)
|
||||||
|
|
||||||
if self.framework == "pt":
|
# set input_ids to None to allow empty prompt
|
||||||
|
if inputs["input_ids"].shape[-1] == 0:
|
||||||
|
inputs["input_ids"] = None
|
||||||
|
inputs["attention_mask"] = None
|
||||||
|
|
||||||
|
if self.framework == "pt" and inputs["input_ids"] is not None:
|
||||||
inputs = self.ensure_tensor_on_device(**inputs)
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
|
|
||||||
input_ids = inputs["input_ids"]
|
input_ids = inputs["input_ids"]
|
||||||
|
|
||||||
# Ensure that batch size = 1 (batch generation not allowed for now)
|
# Ensure that batch size = 1 (batch generation not allowed for now)
|
||||||
assert (
|
assert (
|
||||||
input_ids.shape[0] == 1
|
input_ids is None or input_ids.shape[0] == 1
|
||||||
), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
|
), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
|
||||||
|
|
||||||
output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
|
output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
|
||||||
@@ -590,19 +595,19 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
|
# Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
|
||||||
record["generated_text"] = (
|
if input_ids is None:
|
||||||
prompt_text
|
prompt_length = 0
|
||||||
+ text[
|
else:
|
||||||
len(
|
prompt_length = len(
|
||||||
self.tokenizer.decode(
|
self.tokenizer.decode(
|
||||||
input_ids[0],
|
input_ids[0],
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
)
|
)
|
||||||
) :
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
record["generated_text"] = prompt_text + text[prompt_length:]
|
||||||
|
|
||||||
result.append(record)
|
result.append(record)
|
||||||
results += [result]
|
results += [result]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user