Added XLM to run_generation, with prompt language selection.
This commit is contained in:
@@ -26,12 +26,13 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
|
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig
|
||||||
|
|
||||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||||
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
||||||
from transformers import XLNetLMHeadModel, XLNetTokenizer
|
from transformers import XLNetLMHeadModel, XLNetTokenizer
|
||||||
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
||||||
|
from transformers import XLMWithLMHeadModel, XLMTokenizer
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
@@ -41,13 +42,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig)), ())
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
|
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||||
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
|
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
|
||||||
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
|
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
|
||||||
|
'xlm': (XLMWithLMHeadModel, XLMTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||||
@@ -103,7 +105,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
|
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False,
|
||||||
|
xlm_lang=None, device='cpu'):
|
||||||
context = torch.tensor(context, dtype=torch.long, device=device)
|
context = torch.tensor(context, dtype=torch.long, device=device)
|
||||||
context = context.unsqueeze(0).repeat(num_samples, 1)
|
context = context.unsqueeze(0).repeat(num_samples, 1)
|
||||||
generated = context
|
generated = context
|
||||||
@@ -121,6 +124,9 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
|||||||
target_mapping[0, 0, -1] = 1.0 # predict last token
|
target_mapping[0, 0, -1] = 1.0 # predict last token
|
||||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
||||||
|
|
||||||
|
if xlm_lang is not None:
|
||||||
|
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1]).view(1, -1)
|
||||||
|
|
||||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
||||||
next_token_logits = outputs[0][0, -1, :] / temperature
|
next_token_logits = outputs[0][0, -1, :] / temperature
|
||||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||||
@@ -137,6 +143,7 @@ def main():
|
|||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||||
parser.add_argument("--prompt", type=str, default="")
|
parser.add_argument("--prompt", type=str, default="")
|
||||||
parser.add_argument("--padding_text", type=str, default="")
|
parser.add_argument("--padding_text", type=str, default="")
|
||||||
|
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
|
||||||
parser.add_argument("--length", type=int, default=20)
|
parser.add_argument("--length", type=int, default=20)
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
parser.add_argument("--top_k", type=int, default=0)
|
parser.add_argument("--top_k", type=int, default=0)
|
||||||
@@ -168,6 +175,17 @@ def main():
|
|||||||
|
|
||||||
print(args)
|
print(args)
|
||||||
while True:
|
while True:
|
||||||
|
xlm_lang = None
|
||||||
|
# XLM Language usage detailed in the issues #1414
|
||||||
|
if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id'):
|
||||||
|
if args.xlm_lang:
|
||||||
|
language = args.xlm_lang
|
||||||
|
else:
|
||||||
|
language = None
|
||||||
|
while language not in tokenizer.lang2id.keys():
|
||||||
|
language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
|
||||||
|
xlm_lang = tokenizer.lang2id[language]
|
||||||
|
|
||||||
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||||
if args.model_type in ["transfo-xl", "xlnet"]:
|
if args.model_type in ["transfo-xl", "xlnet"]:
|
||||||
# Models with memory likes to have a long prompt for short inputs.
|
# Models with memory likes to have a long prompt for short inputs.
|
||||||
@@ -180,11 +198,12 @@ def main():
|
|||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
device=args.device,
|
|
||||||
is_xlnet=bool(args.model_type == "xlnet"),
|
is_xlnet=bool(args.model_type == "xlnet"),
|
||||||
|
xlm_lang=xlm_lang,
|
||||||
|
device=args.device,
|
||||||
)
|
)
|
||||||
out = out[0, len(context_tokens):].tolist()
|
out = out[0, len(context_tokens):].tolist()
|
||||||
text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
|
text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
|
||||||
print(text)
|
print(text)
|
||||||
if args.prompt:
|
if args.prompt:
|
||||||
break
|
break
|
||||||
|
|||||||
Reference in New Issue
Block a user