fixing run_generation example - using torch.no_grad
This commit is contained in:
@@ -87,11 +87,11 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
 logger.info(
 "WARNING! You are not starting your generation from a control code so you won't get good results"
 )
-return prompt_text, {}
+return prompt_text


 def prepare_xlm_input(args, model, tokenizer, prompt_text):
-kwargs = {"language": None, "mask_token_id": None}
+# kwargs = {"language": None, "mask_token_id": None}

 # Set the language
 use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
@@ -107,14 +107,15 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
 + str(list(available_languages))
 + " >>> "
 )
-kwargs["language"] = tokenizer.lang2id[language]
+# kwargs["language"] = tokenizer.lang2id[language]

+# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
 # XLM masked-language modeling (MLM) models need masked token
-is_xlm_mlm = "mlm" in args.model_name_or_path
-if is_xlm_mlm:
-kwargs["mask_token_id"] = tokenizer.mask_token_id
+# is_xlm_mlm = "mlm" in args.model_name_or_path
+# if is_xlm_mlm:
+# kwargs["mask_token_id"] = tokenizer.mask_token_id

-return prompt_text, kwargs
+return prompt_text


 def prepare_xlnet_input(args, _, tokenizer, prompt_text):
@@ -179,8 +180,8 @@ def main():
 try:
 args.model_type = args.model_type.lower()
 model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-except KeyError as ke:
-raise ke(
+except KeyError:
+raise KeyError(
 "the model {} you specified is not supported. You are welcome to add it and open a PR :)"
 )

@@ -197,10 +198,9 @@ def main():

 # Different models need different input formatting and/or extra arguments
 requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
-model_kwargs = {}
 if requires_preprocessing:
 prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
-prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
+prompt_text = prepare_input(args, model, tokenizer, prompt_text)
 encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')

 output_sequences = model.generate(
@@ -210,14 +210,11 @@ def main():
 top_k=args.k,
 top_p=args.p,
 repetition_penalty=args.repetition_penalty,
-**model_kwargs,
 )

-generated_sequence = output_sequences.tolist()[
-encoded_prompt.size(1) :
-] # adapted to case where num_samples > 1
-text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-text = text[: text.find(args.stop_token) if args.stop_token else None]
+generated_sequence = output_sequences.tolist()
+text = [tokenizer.decode(seq, clean_up_tokenization_spaces=True) for seq in generated_sequence]
+# text = text[: text.find(args.stop_token) if args.stop_token else None]

 print(text)

Reference in New Issue
Block a user