further cleanup
This commit is contained in:
@@ -91,7 +91,7 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
||||
|
||||
|
||||
def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
kwargs = {"language": None, "mask_token": None}
|
||||
kwargs = {"language": None, "mask_token_id": None}
|
||||
|
||||
# Set the language
|
||||
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
|
||||
@@ -112,7 +112,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
# XLM masked-language modeling (MLM) models need masked token
|
||||
is_xlm_mlm = "mlm" in args.model_name_or_path
|
||||
if is_xlm_mlm:
|
||||
kwargs["mask_token"] = tokenizer.mask_token_id
|
||||
kwargs["mask_token_id"] = tokenizer.mask_token_id
|
||||
|
||||
return prompt_text, kwargs
|
||||
|
||||
@@ -204,14 +204,13 @@ def main():
|
||||
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
|
||||
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
|
||||
|
||||
output_sequences = model.decode(
|
||||
prompt_ids=encoded_prompt,
|
||||
output_sequences = model.generate(
|
||||
intput_ids=encoded_prompt,
|
||||
length=args.length,
|
||||
temperature=args.temperature,
|
||||
k=args.k,
|
||||
p=args.p,
|
||||
top_k=args.k,
|
||||
top_p=args.p,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
device=args.device,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user