From 9d9b872b66f9ab9b7b7c73f2c00985dd92c4121b Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 6 Jul 2020 12:17:05 -0400 Subject: [PATCH] The `add_space_before_punct_symbol` is only for TransfoXL (#5549) --- examples/text-generation/run_generation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index b4b91542e9..40017733ec 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -214,8 +214,14 @@ def main(): if requires_preprocessing: prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) + + if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: + tokenizer_kwargs = {"add_space_before_punct_symbol": True} + else: + tokenizer_kwargs = {} + encoded_prompt = tokenizer.encode( - preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True + preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs ) else: encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")