Improve special_token_id logic in run_generation.py and add tests (#2885)
* improving generation * finalized special token behaviour for no_beam_search generation * solved modeling_utils merge conflict * solve merge conflicts in modeling_utils.py * add run_generation improvements from PR #2749 * adapted language generation to not use hardcoded -1 if no padding token is available * remove the -1 removal as hard coded -1`s are not necessary anymore * add lightweight language generation testing for randomely initialized models - just checking whether no errors are thrown * add slow language generation tests for pretrained models using hardcoded output with pytorch seed * delete ipdb * check that all generated tokens are valid * renaming * renaming Generation -> Generate * make style * updated so that generate_beam_search has same token behavior than generate_no_beam_search * consistent return format for run_generation.py * deleted pretrain lm generate tests -> will be added in another PR * cleaning of unused if statements and renaming * run_generate will always return an iterable * make style * consistent renaming * improve naming, make sure generate function always returns the same tensor, add docstring * add slow tests for all lmhead models * make style and improve example comments modeling_utils * better naming and refactoring in modeling_utils * improving generation * finalized special token behaviour for no_beam_search generation * solved modeling_utils merge conflict * solve merge conflicts in modeling_utils.py * add run_generation improvements from PR #2749 * adapted language generation to not use hardcoded -1 if no padding token is available * remove the -1 removal as hard coded -1`s are not necessary anymore * add lightweight language generation testing for randomely initialized models - just checking whether no errors are thrown * add slow language generation tests for pretrained models using hardcoded output with pytorch seed * delete ipdb * check that all generated tokens are valid * renaming * renaming Generation -> Generate * make style * updated so that generate_beam_search has same token behavior than generate_no_beam_search * consistent return format for run_generation.py * deleted pretrain lm generate tests -> will be added in another PR * cleaning of unused if statements and renaming * run_generate will always return an iterable * make style * consistent renaming * improve naming, make sure generate function always returns the same tensor, add docstring * add slow tests for all lmhead models * make style and improve example comments modeling_utils * better naming and refactoring in modeling_utils * changed fast random lm generation testing design to more general one * delete in old testing design in gpt2 * correct old variable name * temporary fix for encoder_decoder lm generation tests - has to be updated when t5 is fixed * adapted all fast random generate tests to new design * better warning description in modeling_utils * better comment * better comment and error message Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c749a543fa
commit
fc38d4c86f
@@ -106,6 +106,8 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
language = None
|
||||
while language not in available_languages:
|
||||
language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
|
||||
|
||||
model.config.lang_id = model.config.lang2id[language]
|
||||
# kwargs["language"] = tokenizer.lang2id[language]
|
||||
|
||||
# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
|
||||
@@ -119,12 +121,12 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
|
||||
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
|
||||
prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
|
||||
return prompt_text, {}
|
||||
return prompt_text
|
||||
|
||||
|
||||
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
|
||||
prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
|
||||
return prompt_text, {}
|
||||
return prompt_text
|
||||
|
||||
|
||||
PREPROCESSING_FUNCTIONS = {
|
||||
@@ -183,6 +185,7 @@ def main():
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
@@ -210,28 +213,48 @@ def main():
|
||||
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
|
||||
if requires_preprocessing:
|
||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||
prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||
encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
else:
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
encoded_prompt = encoded_prompt.to(args.device)
|
||||
|
||||
output_sequences = model.generate(
|
||||
input_ids=encoded_prompt,
|
||||
max_length=args.length,
|
||||
max_length=args.length + len(encoded_prompt[0]),
|
||||
temperature=args.temperature,
|
||||
top_k=args.k,
|
||||
top_p=args.p,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
do_sample=True,
|
||||
num_return_sequences=args.num_return_sequences,
|
||||
)
|
||||
|
||||
# Batch size == 1. to add more examples please use num_return_sequences > 1
|
||||
generated_sequence = output_sequences[0].tolist()
|
||||
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
||||
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||
# Remove the batch dimension when returning multiple sequences
|
||||
if len(output_sequences.shape) > 2:
|
||||
output_sequences.squeeze_()
|
||||
|
||||
print(text)
|
||||
generated_sequences = []
|
||||
|
||||
return text
|
||||
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
||||
print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
|
||||
generated_sequence = generated_sequence.tolist()
|
||||
|
||||
# Decode text
|
||||
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
||||
|
||||
# Remove all text after the stop token
|
||||
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||
|
||||
# Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
|
||||
total_sequence = (
|
||||
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
|
||||
)
|
||||
|
||||
generated_sequences.append(total_sequence)
|
||||
print(total_sequence)
|
||||
|
||||
return generated_sequences
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -97,4 +97,4 @@ class ExamplesTests(unittest.TestCase):
|
||||
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
|
||||
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||
result = run_generation.main()
|
||||
self.assertGreaterEqual(len(result), 10)
|
||||
self.assertGreaterEqual(len(result[0]), 10)
|
||||
|
||||
Reference in New Issue
Block a user