From c2c2ca0fdba6f16f2e66d0a21152aa4ae493ca78 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 3 Oct 2019 17:18:48 -0400 Subject: [PATCH] Added XLM to run_generation, with prompt language selection. --- examples/run_generation.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 9e98a9e870..a70a0e7842 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -26,12 +26,13 @@ import torch import torch.nn.functional as F import numpy as np -from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig +from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig from transformers import GPT2LMHeadModel, GPT2Tokenizer from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer from transformers import XLNetLMHeadModel, XLNetTokenizer from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer +from transformers import XLMWithLMHeadModel, XLMTokenizer logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', @@ -41,13 +42,14 @@ logger = logging.getLogger(__name__) MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop -ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ()) +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig)), ()) MODEL_CLASSES = { 'gpt2': (GPT2LMHeadModel, GPT2Tokenizer), 'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 'xlnet': (XLNetLMHeadModel, XLNetTokenizer), 'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer), + 'xlm': (XLMWithLMHeadModel, XLMTokenizer), } # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia @@ -103,7 +105,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf') return logits -def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'): +def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, + xlm_lang=None, device='cpu'): context = torch.tensor(context, dtype=torch.long, device=device) context = context.unsqueeze(0).repeat(num_samples, 1) generated = context @@ -121,6 +124,9 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k= target_mapping[0, 0, -1] = 1.0 # predict last token inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping} + if xlm_lang is not None: + inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1]).view(1, -1) + outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) next_token_logits = outputs[0][0, -1, :] / temperature filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) @@ -137,6 +143,7 @@ def main(): help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) parser.add_argument("--prompt", type=str, default="") parser.add_argument("--padding_text", type=str, default="") + parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.") parser.add_argument("--length", type=int, default=20) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=0) @@ -168,6 +175,17 @@ def main(): print(args) while True: + xlm_lang = None + # XLM Language usage detailed in the issues #1414 + if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id'): + if args.xlm_lang: + language = args.xlm_lang + else: + language = None + while language not in tokenizer.lang2id.keys(): + language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ") + xlm_lang = tokenizer.lang2id[language] + raw_text = args.prompt if args.prompt else input("Model prompt >>> ") if args.model_type in ["transfo-xl", "xlnet"]: # Models with memory likes to have a long prompt for short inputs. @@ -180,11 +198,12 @@ def main(): temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, - device=args.device, is_xlnet=bool(args.model_type == "xlnet"), + xlm_lang=xlm_lang, + device=args.device, ) out = out[0, len(context_tokens):].tolist() - text = tokenizer.decode(out, clean_up_tokenization_spaces=True) + text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True) print(text) if args.prompt: break