feat: allow prefix for any generative model (#5885)
* feat: allow padding_text for any generative model * docs(pipelines.py): correct typo * Update src/transformers/pipelines.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * feat: rename padding_text to prefix * fix: cannot tokenize empty text * fix: pass prefix arg to pipeline * test: add prefix to text-generation pipeline * style: fix style * style: clean code and variable name more explicit * set arg docstring to optional Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -61,7 +61,7 @@ MODEL_CLASSES = {
|
||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
||||
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
||||
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||
(except for Alexei and Maria) are discovered.
|
||||
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
||||
remainder of the story. 1883 Western Siberia,
|
||||
@@ -122,12 +122,14 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
|
||||
|
||||
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
|
||||
prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
|
||||
prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
|
||||
prompt_text = prefix + prompt_text
|
||||
return prompt_text
|
||||
|
||||
|
||||
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
|
||||
prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
|
||||
prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
|
||||
prompt_text = prefix + prompt_text
|
||||
return prompt_text
|
||||
|
||||
|
||||
@@ -182,7 +184,8 @@ def main():
|
||||
parser.add_argument("--k", type=int, default=0)
|
||||
parser.add_argument("--p", type=float, default=0.9)
|
||||
|
||||
parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
|
||||
parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
|
||||
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
|
||||
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
@@ -241,7 +244,8 @@ def main():
|
||||
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
|
||||
)
|
||||
else:
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
prefix = args.prefix if args.prefix else args.padding_text
|
||||
encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
encoded_prompt = encoded_prompt.to(args.device)
|
||||
|
||||
if encoded_prompt.size()[-1] == 0:
|
||||
|
||||
Reference in New Issue
Block a user