diff --git a/examples/pytorch/text-generation/run_generation.py b/examples/pytorch/text-generation/run_generation.py index 42cd9528e1..9943a4f54a 100755 --- a/examples/pytorch/text-generation/run_generation.py +++ b/examples/pytorch/text-generation/run_generation.py @@ -38,8 +38,6 @@ from transformers import ( OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OPTForCausalLM, - TransfoXLLMHeadModel, - TransfoXLTokenizer, XLMTokenizer, XLMWithLMHeadModel, XLNetLMHeadModel, @@ -62,7 +60,6 @@ MODEL_CLASSES = { "ctrl": (CTRLLMHeadModel, CTRLTokenizer), "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), "xlnet": (XLNetLMHeadModel, XLNetTokenizer), - "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), "xlm": (XLMWithLMHeadModel, XLMTokenizer), "gptj": (GPTJForCausalLM, AutoTokenizer), "bloom": (BloomForCausalLM, BloomTokenizerFast), @@ -368,10 +365,7 @@ def main(): prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) - if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: - tokenizer_kwargs = {"add_space_before_punct_symbol": True} - else: - tokenizer_kwargs = {} + tokenizer_kwargs = {} encoded_prompt = tokenizer.encode( preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs