Merge pull request #1840 from huggingface/generation_sampler
[WIP] Sampling sequence generator for transformers
This commit is contained in:
@@ -20,14 +20,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
|
|
||||||
|
|
||||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||||
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
||||||
from transformers import XLNetLMHeadModel, XLNetTokenizer
|
from transformers import XLNetLMHeadModel, XLNetTokenizer
|
||||||
@@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer
|
|||||||
from transformers import XLMWithLMHeadModel, XLMTokenizer
|
from transformers import XLMWithLMHeadModel, XLMTokenizer
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
level = logging.INFO)
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
|
"gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
|
"ctrl": (CTRLLMHeadModel, CTRLTokenizer),
|
||||||
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
"openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||||
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
|
"xlnet": (XLNetLMHeadModel, XLNetTokenizer),
|
||||||
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
|
"transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
|
||||||
'xlm': (XLMWithLMHeadModel, XLMTokenizer),
|
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||||
@@ -75,81 +71,79 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
#
|
||||||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
# Functions to prepare models' input
|
||||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
#
|
||||||
Args:
|
|
||||||
logits: logits distribution shape (batch size x vocabulary size)
|
|
||||||
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
|
||||||
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
|
||||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
|
||||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
|
||||||
"""
|
|
||||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
|
||||||
if top_k > 0:
|
|
||||||
# Remove all tokens with a probability less than the last token of the top-k
|
|
||||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
|
|
||||||
if top_p > 0.0:
|
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
||||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
|
||||||
|
|
||||||
# Remove tokens with cumulative probability above the threshold
|
|
||||||
sorted_indices_to_remove = cumulative_probs > top_p
|
|
||||||
# Shift the indices to the right to keep also the first token above the threshold
|
|
||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
|
||||||
|
|
||||||
# scatter sorted tensors to original indexing
|
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
|
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
||||||
is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
|
if args.temperature > 0.7:
|
||||||
context = torch.tensor(context, dtype=torch.long, device=device)
|
logger.info(
|
||||||
context = context.unsqueeze(0).repeat(num_samples, 1)
|
"CTRL typically works better with lower temperatures (and lower top_k)."
|
||||||
generated = context
|
)
|
||||||
with torch.no_grad():
|
|
||||||
for _ in trange(length):
|
|
||||||
|
|
||||||
inputs = {'input_ids': generated}
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||||
if is_xlnet:
|
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
|
||||||
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
|
logger.info(
|
||||||
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
|
"WARNING! You are not starting your generation from a control code so you won't get good results"
|
||||||
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
|
)
|
||||||
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
|
return prompt_text
|
||||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
|
||||||
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
|
|
||||||
target_mapping[0, 0, -1] = 1.0 # predict last token
|
|
||||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
|
||||||
|
|
||||||
if is_xlm_mlm and xlm_mask_token:
|
|
||||||
# XLM MLM models are direct models (predict same token, not next token)
|
|
||||||
# => need one additional dummy token in the input (will be masked and guessed)
|
|
||||||
input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
|
|
||||||
inputs = {'input_ids': input_ids}
|
|
||||||
|
|
||||||
if xlm_lang is not None:
|
def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||||
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
|
# kwargs = {"language": None, "mask_token_id": None}
|
||||||
|
|
||||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
|
# Set the language
|
||||||
next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
|
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
|
||||||
|
if hasattr(model.config, "lang2id") and use_lang_emb:
|
||||||
|
available_languages = model.config.lang2id.keys()
|
||||||
|
if args.xlm_language in available_languages:
|
||||||
|
language = args.xlm_language
|
||||||
|
else:
|
||||||
|
language = None
|
||||||
|
while language not in available_languages:
|
||||||
|
language = input(
|
||||||
|
"Using XLM. Select language in "
|
||||||
|
+ str(list(available_languages))
|
||||||
|
+ " >>> "
|
||||||
|
)
|
||||||
|
# kwargs["language"] = tokenizer.lang2id[language]
|
||||||
|
|
||||||
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
|
# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
|
||||||
for i in range(num_samples):
|
# XLM masked-language modeling (MLM) models need masked token
|
||||||
for _ in set(generated[i].tolist()):
|
# is_xlm_mlm = "mlm" in args.model_name_or_path
|
||||||
next_token_logits[i, _] /= repetition_penalty
|
# if is_xlm_mlm:
|
||||||
|
# kwargs["mask_token_id"] = tokenizer.mask_token_id
|
||||||
|
|
||||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
return prompt_text
|
||||||
if temperature == 0: # greedy sampling:
|
|
||||||
next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
|
|
||||||
else:
|
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
|
||||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
|
||||||
generated = torch.cat((generated, next_token), dim=1)
|
return prompt_text, {}
|
||||||
return generated
|
|
||||||
|
|
||||||
|
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    """Prepend padding text to a Transformer-XL prompt.

    Args:
        args: parsed command-line namespace; only ``args.padding_text`` is read.
        _: unused positional slot (the sibling preprocessing functions receive
            the model here).
        tokenizer: unused; kept so every preprocessing function shares one signature.
        prompt_text: the raw prompt string.

    Returns:
        The prompt string prefixed with ``args.padding_text`` when it is
        non-empty, otherwise with the module-level ``PADDING_TEXT``.
    """
    padding = args.padding_text if args.padding_text else PADDING_TEXT
    # Return the string alone: the caller assigns the result to `prompt_text`
    # and passes it straight to `tokenizer.encode`, so a `(text, {})` tuple
    # would not be a valid prompt.
    return padding + prompt_text
|
||||||
|
|
||||||
|
|
||||||
|
PREPROCESSING_FUNCTIONS = {
|
||||||
|
"ctrl": prepare_ctrl_input,
|
||||||
|
"xlm": prepare_xlm_input,
|
||||||
|
"xlnet": prepare_xlnet_input,
|
||||||
|
"transfo-xl": prepare_transfoxl_input,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_length_to_model(length, max_sequence_length):
    """Clamp the requested generation length to what the model can handle.

    Args:
        length: requested number of tokens; a negative value means "no explicit
            limit requested".
        max_sequence_length: the model's maximum sequence length; a value <= 0
            is treated as "no model limit".

    Returns:
        ``max_sequence_length`` when the request is negative or exceeds a
        positive model limit; the module-level ``MAX_LENGTH`` when the request
        is negative and the model has no positive limit; otherwise ``length``
        unchanged.
    """
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -157,108 +151,76 @@ def main():
|
|||||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||||
|
|
||||||
parser.add_argument("--prompt", type=str, default="")
|
parser.add_argument("--prompt", type=str, default="")
|
||||||
parser.add_argument("--padding_text", type=str, default="")
|
|
||||||
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
|
|
||||||
parser.add_argument("--length", type=int, default=20)
|
parser.add_argument("--length", type=int, default=20)
|
||||||
parser.add_argument("--num_samples", type=int, default=1)
|
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
||||||
parser.add_argument("--temperature", type=float, default=1.0,
|
|
||||||
help="temperature of 0 implies greedy sampling")
|
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
|
||||||
parser.add_argument("--repetition_penalty", type=float, default=1.0,
|
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
|
||||||
help="primarily useful for CTRL model; in that case, use 1.2")
|
parser.add_argument("--k", type=int, default=0)
|
||||||
parser.add_argument("--top_k", type=int, default=0)
|
parser.add_argument("--p", type=float, default=0.9)
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
|
||||||
help="Avoid using CUDA when available")
|
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
|
||||||
help="random seed for initialization")
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||||
parser.add_argument('--stop_token', type=str, default=None,
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||||
help="Token at which text generation is stopped")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
args.device = torch.device(
|
||||||
|
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
||||||
|
)
|
||||||
args.n_gpu = torch.cuda.device_count()
|
args.n_gpu = torch.cuda.device_count()
|
||||||
|
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
# Initialize the model and tokenizer
|
||||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
try:
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
except KeyError:
    # Fill the "{}" placeholder so the error names the unsupported model type;
    # the message was previously raised with the placeholder left empty.
    raise KeyError(
        "the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(
            args.model_type
        )
    )
|
||||||
|
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||||
model = model_class.from_pretrained(args.model_name_or_path)
|
model = model_class.from_pretrained(args.model_name_or_path)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
if args.length < 0 and model.config.max_position_embeddings > 0:
|
|
||||||
args.length = model.config.max_position_embeddings
|
|
||||||
elif 0 < model.config.max_position_embeddings < args.length:
|
|
||||||
args.length = model.config.max_position_embeddings # No generation bigger than model size
|
|
||||||
elif args.length < 0:
|
|
||||||
args.length = MAX_LENGTH # avoid infinite loop
|
|
||||||
|
|
||||||
|
args.length = adjust_length_to_model(
|
||||||
|
args.length, max_sequence_length=model.config.max_position_embeddings
|
||||||
|
)
|
||||||
logger.info(args)
|
logger.info(args)
|
||||||
if args.model_type in ["ctrl"]:
|
|
||||||
if args.temperature > 0.7:
|
|
||||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
|
||||||
|
|
||||||
while True:
|
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||||
xlm_lang = None
|
|
||||||
# XLM Language usage detailed in the issues #1414
|
|
||||||
if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
|
|
||||||
and model.config.use_lang_emb:
|
|
||||||
if args.xlm_lang:
|
|
||||||
language = args.xlm_lang
|
|
||||||
else:
|
|
||||||
language = None
|
|
||||||
while language not in tokenizer.lang2id.keys():
|
|
||||||
language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
|
|
||||||
xlm_lang = tokenizer.lang2id[language]
|
|
||||||
|
|
||||||
# XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
|
# Different models need different input formatting and/or extra arguments
|
||||||
is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
|
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
|
||||||
if is_xlm_mlm:
|
if requires_preprocessing:
|
||||||
xlm_mask_token = tokenizer.mask_token_id
|
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||||
else:
|
prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||||
xlm_mask_token = None
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
|
||||||
|
|
||||||
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
output_sequences = model.generate(
|
||||||
if args.model_type in ["transfo-xl", "xlnet"]:
|
input_ids=encoded_prompt,
|
||||||
# Models with memory likes to have a long prompt for short inputs.
|
max_length=args.length,
|
||||||
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
|
temperature=args.temperature,
|
||||||
context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
|
top_k=args.k,
|
||||||
if args.model_type == "ctrl":
|
top_p=args.p,
|
||||||
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
|
repetition_penalty=args.repetition_penalty,
|
||||||
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
)
|
||||||
out = sample_sequence(
|
|
||||||
model=model,
|
|
||||||
context=context_tokens,
|
|
||||||
num_samples=args.num_samples,
|
|
||||||
length=args.length,
|
|
||||||
temperature=args.temperature,
|
|
||||||
top_k=args.top_k,
|
|
||||||
top_p=args.top_p,
|
|
||||||
repetition_penalty=args.repetition_penalty,
|
|
||||||
is_xlnet=bool(args.model_type == "xlnet"),
|
|
||||||
is_xlm_mlm=is_xlm_mlm,
|
|
||||||
xlm_mask_token=xlm_mask_token,
|
|
||||||
xlm_lang=xlm_lang,
|
|
||||||
device=args.device,
|
|
||||||
)
|
|
||||||
out = out[:, len(context_tokens):].tolist()
|
|
||||||
for o in out:
|
|
||||||
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
|
|
||||||
if args.stop_token:
|
|
||||||
index = text.find(args.stop_token)
|
|
||||||
if index == -1:
|
|
||||||
index = None
|
|
||||||
text = text[:index]
|
|
||||||
|
|
||||||
print(text)
|
# Batch size == 1. to add more examples please use num_return_sequences > 1
|
||||||
|
generated_sequence = output_sequences[0].tolist()
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
# Cut the text at the first occurrence of the stop token, if one was requested.
# `str.find` returns -1 when the token is absent; slicing with -1 would drop
# the final character, so the text is left untouched in that case.
if args.stop_token:
    stop_index = text.find(args.stop_token)
    if stop_index != -1:
        text = text[:stop_index]
|
||||||
|
|
||||||
|
print(text)
|
||||||
|
|
||||||
if args.prompt:
|
|
||||||
break
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -56,8 +56,24 @@ class PretrainedConfig(object):
|
|||||||
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
|
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
|
||||||
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
|
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
|
||||||
self.pruned_heads = kwargs.pop('pruned_heads', {})
|
self.pruned_heads = kwargs.pop('pruned_heads', {})
|
||||||
|
|
||||||
|
# `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
|
||||||
self.is_decoder = kwargs.pop('is_decoder', False)
|
self.is_decoder = kwargs.pop('is_decoder', False)
|
||||||
|
|
||||||
|
# Parameters for sequence generation
|
||||||
|
self.max_length = kwargs.pop('max_length', 20)
|
||||||
|
self.do_sample = kwargs.pop('do_sample', False)
|
||||||
|
self.num_beams = kwargs.pop('num_beams', 1)
|
||||||
|
self.temperature = kwargs.pop('temperature', 1.0)
|
||||||
|
self.top_k = kwargs.pop('top_k', 50)
|
||||||
|
self.top_p = kwargs.pop('top_p', 1.0)
|
||||||
|
self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
|
||||||
|
self.bos_token_id = kwargs.pop('bos_token_id', 0)
|
||||||
|
self.pad_token_id = kwargs.pop('pad_token_id', 0)
|
||||||
|
self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
|
||||||
|
self.length_penalty = kwargs.pop('length_penalty', 1.)
|
||||||
|
self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
|
||||||
|
|
||||||
# Fine-tuning task arguments
|
# Fine-tuning task arguments
|
||||||
self.finetuning_task = kwargs.pop('finetuning_task', None)
|
self.finetuning_task = kwargs.pop('finetuning_task', None)
|
||||||
self.num_labels = kwargs.pop('num_labels', 2)
|
self.num_labels = kwargs.pop('num_labels', 2)
|
||||||
|
|||||||
@@ -110,6 +110,8 @@ class XLMConfig(PretrainedConfig):
|
|||||||
summary_first_dropout=0.1,
|
summary_first_dropout=0.1,
|
||||||
start_n_top=5,
|
start_n_top=5,
|
||||||
end_n_top=5,
|
end_n_top=5,
|
||||||
|
mask_token_id=0,
|
||||||
|
lang_id=0,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Constructs XLMConfig.
|
"""Constructs XLMConfig.
|
||||||
"""
|
"""
|
||||||
@@ -143,6 +145,8 @@ class XLMConfig(PretrainedConfig):
|
|||||||
self.summary_first_dropout = summary_first_dropout
|
self.summary_first_dropout = summary_first_dropout
|
||||||
self.start_n_top = start_n_top
|
self.start_n_top = start_n_top
|
||||||
self.end_n_top = end_n_top
|
self.end_n_top = end_n_top
|
||||||
|
self.mask_token_id = mask_token_id
|
||||||
|
self.lang_id = lang_id
|
||||||
|
|
||||||
if "n_words" in kwargs:
|
if "n_words" in kwargs:
|
||||||
self.n_words = kwargs["n_words"]
|
self.n_words = kwargs["n_words"]
|
||||||
|
|||||||
@@ -18,9 +18,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||||
|
|
||||||
@@ -119,8 +121,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
kwargs_common = {
|
kwargs_common = {
|
||||||
argument: value
|
argument: value
|
||||||
for argument, value in kwargs.items()
|
for argument, value in kwargs.items()
|
||||||
if not argument.startswith("encoder_")
|
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||||
and not argument.startswith("decoder_")
|
|
||||||
}
|
}
|
||||||
kwargs_decoder = kwargs_common.copy()
|
kwargs_decoder = kwargs_common.copy()
|
||||||
kwargs_encoder = kwargs_common.copy()
|
kwargs_encoder = kwargs_common.copy()
|
||||||
@@ -220,32 +221,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
Indices of decoder input sequence tokens in the vocabulary.
|
Indices of decoder input sequence tokens in the vocabulary.
|
||||||
kwargs: (`optional`) Remaining dictionary of keyword arguments.
|
kwargs: (`optional`) Remaining dictionary of keyword arguments.
|
||||||
"""
|
"""
|
||||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
|
||||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
|
||||||
# that apply to the model as whole.
|
|
||||||
# We let the specific kwargs override the common ones in case of conflict.
|
|
||||||
kwargs_common = {
|
|
||||||
argument: value
|
|
||||||
for argument, value in kwargs.items()
|
|
||||||
if not argument.startswith("encoder_")
|
|
||||||
and not argument.startswith("decoder_")
|
|
||||||
}
|
|
||||||
kwargs_decoder = kwargs_common.copy()
|
|
||||||
kwargs_encoder = kwargs_common.copy()
|
|
||||||
kwargs_encoder.update(
|
|
||||||
{
|
|
||||||
argument[len("encoder_") :]: value
|
|
||||||
for argument, value in kwargs.items()
|
|
||||||
if argument.startswith("encoder_")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
kwargs_decoder.update(
|
|
||||||
{
|
|
||||||
argument[len("decoder_") :]: value
|
|
||||||
for argument, value in kwargs.items()
|
|
||||||
if argument.startswith("decoder_")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Encode if needed (training, first prediction pass)
|
# Encode if needed (training, first prediction pass)
|
||||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||||
@@ -255,15 +231,47 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
encoder_outputs = ()
|
encoder_outputs = ()
|
||||||
|
|
||||||
# Decode
|
|
||||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||||
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get(
|
decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder)
|
||||||
"attention_mask", None
|
|
||||||
)
|
|
||||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
@staticmethod
def prepare_model_kwargs(**kwargs):
    """ Prepare the encoder and decoder's keyword arguments.

    Keyword arguments come in 3 flavors:
    - encoder-specific (prefixed by `encoder_`)
    - decoder-specific (prefixed by `decoder_`)
    - those that apply to the model as whole.

    We let the specific kwargs override the common ones in case of
    conflict.

    Returns:
        A tuple ``(encoder_kwargs, decoder_kwargs)`` of two dicts. Prefixed
        arguments are stored with their prefix stripped. ``decoder_kwargs``
        always contains ``"encoder_attention_mask"``, set to the encoder's
        ``"attention_mask"`` entry (``None`` when there is none).
    """
    kwargs_common = {
        argument: value
        for argument, value in kwargs.items()
        if not argument.startswith("encoder_") and not argument.startswith("decoder_")
    }
    decoder_kwargs = kwargs_common.copy()
    encoder_kwargs = kwargs_common.copy()
    # Prefix-specific values are applied after the common ones so they win on conflict.
    encoder_kwargs.update(
        {
            argument[len("encoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("encoder_")
        }
    )
    decoder_kwargs.update(
        {
            argument[len("decoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }
    )
    # The decoder receives the encoder's attention mask under this key.
    decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
    return encoder_kwargs, decoder_kwargs
|
||||||
|
|
||||||
|
|
||||||
class Model2Model(PreTrainedEncoderDecoder):
|
class Model2Model(PreTrainedEncoderDecoder):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
|
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
|
||||||
from .configuration_transfo_xl import TransfoXLConfig
|
from .configuration_transfo_xl import TransfoXLConfig
|
||||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
outputs = [softmax_output, None] + outputs
|
outputs = [softmax_output, None] + outputs
|
||||||
|
|
||||||
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
|
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
    """ Return the layer used as output embeddings.

    Double-check if you are using adaptive softmax: when
    ``self.sample_softmax > 0`` this returns ``self.out_layer``; otherwise it
    returns the last entry of ``self.crit.out_layers``.
    """
    if self.sample_softmax > 0:
        return self.out_layer
    else:
        return self.crit.out_layers[-1]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
|
||||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -496,6 +496,403 @@ class PreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Build the model inputs for one generation step.

    This base implementation ignores ``kwargs`` and wraps ``input_ids`` in a
    dict under the ``"input_ids"`` key.
    NOTE(review): presumably model subclasses override this to add extra
    inputs -- confirm against the model implementations.
    """
    return {"input_ids": input_ids}
|
||||||
|
|
||||||
|
@torch.no_grad()
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
             temperature=None, top_k=None, top_p=None, repetition_penalty=None,
             bos_token_id=None, pad_token_id=None, eos_token_ids=None,
             length_penalty=None, num_return_sequences=None):
    """ Sequence generator for models with a LM head.

    The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
    and beam-search.

    Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM

    Every argument left to `None` falls back to the attribute of the same name on `self.config`.

    Params:
        **input_ids**: (`optional`) `torch.LongTensor` of shape (batch_size, sequence_length)
            The sequence(s) used as a prompt for the generation. If `None` the method initializes
            it as a `torch.LongTensor` of shape (1, 1) filled with `bos_token_id`.
        **max_length**: (`optional`) int
            The max length of the sequence to be generated (prompt included). Strictly positive.
        **do_sample**: (`optional`) bool
            If set to `False` we use greedy decoding; otherwise sampling.
        **num_beams**: (`optional`) int
            Number of beams for beam search. 1 means no beam search.
        **temperature**: (`optional`) float
            The value used to modulate the next token probabilities. It is not validated here; the
            decoding helpers only apply it when it is strictly positive and different from 1.
        **top_k**: (`optional`) int
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
            0 disables top-k filtering.
        **top_p**: (`optional`) float
            The cumulative probability of the highest probability vocabulary tokens to keep for nucleus
            sampling. Must be between 0 and 1.
        **repetition_penalty**: (`optional`) float
            The parameter for repetition penalty. Must be >= 1.0; 1.0 means no penalty.
        **bos_token_id**: (`optional`) int
            Beginning of sentence token used as the prompt if no prompt is provided.
        **pad_token_id**: (`optional`) int
            Token used to pad sequences that finished before the others.
        **eos_token_ids**: (`optional`) int or list of int
            End of sequence token or list of tokens to stop the generation.
        **length_penalty**: (`optional`) float
            Exponential penalty to the length. Must be strictly positive.
        **num_return_sequences**: (`optional`) int
            The number of independently computed returned sequences for each element in the batch.

    Returns:
        The tensor of token ids produced by the decoding helper. When `num_return_sequences` is
        different from 1 it is reshaped to (batch_size, num_return_sequences, generated_length).

    Raises:
        AttributeError: if the model has no output embeddings (no LM head).
        AssertionError: if one of the arguments is outside of its valid range.
    """

    # We cannot generate if the model does not have a LM head
    if self.get_output_embeddings() is None:
        raise AttributeError("You tried to generate sequences with a model that does not have a LM Head. "
                             "Please use another model class (e.g. `OpenAIGPTLMHeadModel`)")

    max_length = max_length if max_length is not None else self.config.max_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    temperature = temperature if temperature is not None else self.config.temperature
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences

    if input_ids is not None:
        batch_size = input_ids.shape[0]  # overridden by the input batch_size
    else:
        batch_size = 1
    if isinstance(eos_token_ids, int):
        eos_token_ids = [eos_token_ids]

    assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
    assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
    assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
    assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
    assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
    assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
    assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
    assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
    # `all(...)` is required here: a bare generator expression is always truthy and would
    # let negative token ids through.
    assert isinstance(eos_token_ids, (list, tuple)) and all(e >= 0 for e in eos_token_ids), \
        "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
    assert length_penalty > 0, "`length_penalty` should be strictely positive."
    assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictely positive integer."

    if input_ids is None:
        input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
    else:
        assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

    # current position and vocab size
    cur_len = input_ids.shape[1]
    vocab_size = self.config.vocab_size

    if num_return_sequences != 1:
        # Repeat each prompt `num_return_sequences` times, keeping copies of a prompt contiguous
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len)  # (batch_size * num_return_sequences, cur_len)
        effective_batch_size = batch_size * num_return_sequences
    else:
        effective_batch_size = batch_size

    if num_beams > 1:
        output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
                                            temperature, top_k, top_p, repetition_penalty,
                                            pad_token_id, eos_token_ids, effective_batch_size,
                                            length_penalty, num_beams, vocab_size)
    else:
        output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
                                               temperature, top_k, top_p, repetition_penalty,
                                               pad_token_id, eos_token_ids, effective_batch_size)

    if num_return_sequences != 1:
        # Undo the flattening done above: one row of sequences per original prompt
        output = output.view(batch_size, num_return_sequences, -1)
    return output
|
||||||
|
|
||||||
|
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
|
||||||
|
temperature, top_k, top_p, repetition_penalty,
|
||||||
|
pad_token_id, eos_token_ids, batch_size):
|
||||||
|
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||||
|
All returned sequence are generated independantly.
|
||||||
|
"""
|
||||||
|
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||||
|
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||||
|
|
||||||
|
# TODO: add cached compute states
|
||||||
|
pasts = None
|
||||||
|
|
||||||
|
while cur_len < max_length:
|
||||||
|
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
||||||
|
outputs = self(**model_inputs)
|
||||||
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
|
|
||||||
|
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||||
|
if repetition_penalty != 1.0:
|
||||||
|
for i in range(batch_size):
|
||||||
|
for previous_tokens in set(input_ids[i].tolist()):
|
||||||
|
next_token_logits[i, previous_tokens] /= repetition_penalty
|
||||||
|
|
||||||
|
if do_sample:
|
||||||
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
|
if temperature > 0 and temperature != 1.0:
|
||||||
|
next_token_logits = next_token_logits / temperature
|
||||||
|
# Top-p/top-k filtering
|
||||||
|
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||||
|
# Sample
|
||||||
|
next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
|
||||||
|
else:
|
||||||
|
# Greedy decoding
|
||||||
|
next_token = torch.argmax(next_token_logits, dim=-1)
|
||||||
|
|
||||||
|
# update generations and finished sentences
|
||||||
|
tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
|
||||||
|
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
||||||
|
for eos_token_id in eos_token_ids:
|
||||||
|
unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
|
||||||
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
|
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||||
|
if unfinished_sents.max() == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# add eos_token_ids to unfinished sentences
|
||||||
|
if cur_len == max_length:
|
||||||
|
input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])
|
||||||
|
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
                          temperature, top_k, top_p, repetition_penalty,
                          pad_token_id, eos_token_ids, batch_size,
                          length_penalty, num_beams, vocab_size):
    """ Generate sequences for each example with beam search.

        Args:
            input_ids: tensor of prompt token ids of shape (batch_size, cur_len).
            cur_len: current length of the sequences in `input_ids`.
            max_length: generation stops once the sequences reach this length.
            do_sample: sample two continuations per beam if True, otherwise keep the
                `2 * num_beams` best continuations across the beams of each sentence.
            temperature: divisor applied to the scores when sampling (ignored if <= 0 or == 1).
            top_k, top_p: forwarded to `top_k_top_p_filtering` when sampling.
            repetition_penalty: penalty applied to the scores of already generated tokens (1.0 disables it).
            pad_token_id: token used to pad finished sentences and the returned batch.
            eos_token_ids: list of token ids that finish a hypothesis.
            batch_size: number of rows in `input_ids`.
            length_penalty: exponent forwarded to `BeamHypotheses`.
            num_beams: number of beams kept per sentence.
            vocab_size: size of the last dimension of the model scores.

        Returns:
            Tensor of shape (batch_size, longest best hypothesis + 1) holding, for each sentence, its
            best-scoring hypothesis followed by `eos_token_ids[0]` and padded with `pad_token_id`.
    """
    # Expand input to num beams
    input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
    input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)

    # generated hypotheses
    generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)]

    # scores for each sentence in the beam; all beams but the first start at -1e9 so that
    # the identical copies of the prompt are not all selected at the first step
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

    # TODO: add cached compute states (they will need to be re-ordered with `beam_idx` below)
    pasts = None

    # done sentences
    done = [False for _ in range(batch_size)]

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
        scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
        scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)

        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            for i in range(batch_size * num_beams):
                for previous_token in set(input_ids[i].tolist()):
                    # Dividing a negative score by a penalty > 1 would *raise* it, so negative
                    # scores are multiplied instead: the token always becomes less likely.
                    if scores[i, previous_token] < 0:
                        scores[i, previous_token] *= repetition_penalty
                    else:
                        scores[i, previous_token] /= repetition_penalty

        if do_sample:
            # Temperature (higher temperature => more likely to sample low probability tokens)
            if temperature > 0 and temperature != 1.0:
                scores = scores / temperature
            # Top-p/top-k filtering
            scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)
            # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
            next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
            # Compute next scores
            _scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
            _scores = torch.gather(_scores, -1, next_words)  # (batch_size * num_beams, 2)
            next_scores = _scores + beam_scores[:, None].expand_as(_scores)  # (batch_size * num_beams, 2)
            # Encode the originating beam in the candidate ids (beam * vocab_size + word), as the greedy
            # branch does; otherwise `idx // vocab_size` below would always resolve to beam 0.
            beam_offsets = torch.arange(num_beams, dtype=next_words.dtype, device=next_words.device)
            beam_offsets = beam_offsets.repeat(batch_size) * vocab_size  # (batch_size * num_beams,)
            next_words = next_words + beam_offsets[:, None]  # (batch_size * num_beams, 2)
            # Match shape of greedy beam search
            next_words = next_words.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
            next_scores = next_scores.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
            # Best candidates first, since the selection loop below keeps the first `num_beams` ones
            next_scores, candidate_order = torch.sort(next_scores, dim=1, descending=True)
            next_words = torch.gather(next_words, 1, candidate_order)
        else:
            # do greedy beam search
            scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
            assert scores.size() == (batch_size * num_beams, vocab_size)
            # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
            # re-organize to group the beam together (we are keeping top hypothesis across beams)
            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
            next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

        assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)

        # next batch beam content
        # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
        next_batch_beam = []

        # for each sentence
        for batch_ex in range(batch_size):

            # if we are done with this sentence
            done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
            if done[batch_ex]:
                next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                continue

            # next sentence beam content
            next_sent_beam = []

            # next words for this sentence
            for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):

                # get beam and word IDs
                beam_id = idx // vocab_size
                word_id = idx % vocab_size

                # end of sentence, or next word
                if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
                    generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item())
                else:
                    next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))

                # the beam for next step is full
                if len(next_sent_beam) == num_beams:
                    break

            # update next beam content: the beam is empty at the last step and full otherwise.
            # The parentheses matter: without them the conditional expression swallows the comparison.
            assert len(next_sent_beam) == (0 if cur_len + 1 == max_length else num_beams)
            if len(next_sent_beam) == 0:
                next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
            next_batch_beam.extend(next_sent_beam)
            assert len(next_batch_beam) == num_beams * (batch_ex + 1)

        # sanity check / prepare next batch
        assert len(next_batch_beam) == batch_size * num_beams
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_words = input_ids.new([x[1] for x in next_batch_beam])
        beam_idx = input_ids.new([x[2] for x in next_batch_beam])

        # re-order batch and append the selected words
        input_ids = input_ids[beam_idx, :]
        input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)

        # update current length
        cur_len = cur_len + 1

        # stop when we are done with each sentence
        if all(done):
            break

    # select the best hypotheses
    tgt_len = input_ids.new(batch_size)
    best = []

    for i, hypotheses in enumerate(generated_hyps):
        best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
        tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
        best.append(best_hyp)

    # generate target batch
    decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
    for i, hypo in enumerate(best):
        decoded[i, :tgt_len[i] - 1] = hypo
        decoded[i, tgt_len[i] - 1] = eos_token_ids[0]

    return decoded
|
||||||
|
|
||||||
|
|
||||||
|
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        The filtered entries of `logits` are overwritten in place with `filter_value` and the
        same tensor is returned.

        Args:
            logits: logits distribution of shape (batch size, vocabulary size)
            top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering).
            top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written over the filtered logits.
            min_tokens_to_keep: make sure at least this many tokens per batch example survive.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        # Safety check: never keep fewer than min_tokens_to_keep nor more than the vocabulary
        k = max(top_k, min_tokens_to_keep)
        k = min(k, logits.size(-1))
        # Everything strictly below the k-th best logit of its row is filtered out
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        ranked_logits, ranked_indices = torch.sort(logits, descending=True)
        cumulative_probs = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

        # Mark (in sorted order) the tokens whose cumulative probability exceeds the threshold
        drop_ranked = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Protect the leading tokens so that enough of them survive the shift below
            drop_ranked[..., :min_tokens_to_keep] = 0
        # Shift right by one so that the first token crossing the threshold is kept as well
        drop_ranked[..., 1:] = drop_ranked[..., :-1].clone()
        drop_ranked[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order
        drop = drop_ranked.scatter(dim=1, index=ranked_indices, src=drop_ranked)
        logits[drop] = filter_value

    return logits
|
||||||
|
|
||||||
|
|
||||||
|
class BeamHypotheses(object):
    """ Bounded n-best list of finished hypotheses for one sentence.

        Each entry of `self.hyp` is a `(score, hypothesis)` tuple where the score is the sum of
        log-probabilities divided by `len(hypothesis) ** length_penalty`.
    """

    def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.n_hyp = n_hyp
        self.hyp = []
        self.worst_score = 1e9
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.max_length = max_length - 1  # ignoring bos_token

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty

        # The list is full and the candidate does not beat the current worst entry: ignore it
        if len(self) >= self.n_hyp and not (score > self.worst_score):
            return

        self.hyp.append((score, hyp))
        if len(self) <= self.n_hyp:
            self.worst_score = min(score, self.worst_score)
            return

        # One entry too many: evict the lowest score, the runner-up becomes the new worst
        ranked = sorted([(s, position) for position, (s, _) in enumerate(self.hyp)])
        lowest, runner_up = ranked[0], ranked[1]
        del self.hyp[lowest[1]]
        self.worst_score = runner_up[0]

    def is_done(self, best_sum_logprobs):
        """
        If there are enough hypotheses and that none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
            return False
        if self.early_stopping:
            return True
        best_possible = best_sum_logprobs / self.max_length ** self.length_penalty
        return self.worst_score >= best_possible
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
def __init__(self, nf, nx):
|
def __init__(self, nf, nx):
|
||||||
|
|||||||
@@ -649,6 +649,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.pred_layer.proj
|
return self.pred_layer.proj
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """ Build the model inputs for one generation step.

        A mask token (`self.config.mask_token_id`) is appended to every sequence, so the last
        position of the model output corresponds to that mask token. When `self.config.lang_id`
        is set, a `langs` tensor of the same shape filled with that id is provided as well.

        Args:
            input_ids: `torch.LongTensor` of shape (batch_size, sequence_length).
            **kwargs: ignored.

        Returns:
            dict with keys `"input_ids"` (shape (batch_size, sequence_length + 1)) and `"langs"`
            (same shape, or `None` if no language id is configured).
    """
    mask_token_id = self.config.mask_token_id
    lang_id = self.config.lang_id

    # One mask token per row: a (1, 1) tensor cannot be concatenated to a batch of several
    # sequences (several prompts, beams or returned sequences).
    mask_token = torch.full((input_ids.shape[0], 1), mask_token_id, dtype=torch.long, device=input_ids.device)
    input_ids = torch.cat([input_ids, mask_token], dim=1)
    if lang_id is not None:
        langs = torch.full_like(input_ids, lang_id)
    else:
        langs = None
    return {"input_ids": input_ids, "langs": langs}
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
||||||
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
|
lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||||
transformer_outputs = self.transformer(input_ids,
|
transformer_outputs = self.transformer(input_ids,
|
||||||
|
|||||||
@@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_loss
|
return self.lm_loss
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
    """ Build the model inputs for one generation step.

        A dummy token (id 0) is appended to every sequence; the permutation mask hides it from
        all positions and the target mapping selects it as the only position to predict.

        Args:
            input_ids: `torch.LongTensor` of shape (batch_size, sequence_length).
            **model_kwargs: ignored.

        Returns:
            dict with keys `"input_ids"` (shape (batch_size, sequence_length + 1)),
            `"perm_mask"` (shape (batch_size, sequence_length + 1, sequence_length + 1)) and
            `"target_mapping"` (shape (batch_size, 1, sequence_length + 1)).
    """
    batch_size = input_ids.shape[0]

    # Add dummy token at the end (no attention on this one). One per row: a (1, 1) tensor
    # cannot be concatenated to a batch of several sequences.
    dummy_token = torch.zeros((batch_size, 1), dtype=torch.long, device=input_ids.device)
    input_ids = torch.cat([input_ids, dummy_token], dim=1)
    sequence_length = input_ids.shape[1]

    # Build permutation mask so that previous tokens don't see last token
    perm_mask = torch.zeros(
        (batch_size, sequence_length, sequence_length),
        dtype=torch.float, device=input_ids.device
    )
    perm_mask[:, :, -1] = 1.0

    # We'll only predict the last token, and this for every row of the batch (not only the first one)
    target_mapping = torch.zeros(
        (batch_size, 1, sequence_length),
        dtype=torch.float, device=input_ids.device
    )
    target_mapping[:, 0, -1] = 1.0

    return {"input_ids": input_ids,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping
            }
|
||||||
|
|
||||||
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
||||||
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
|
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||||
transformer_outputs = self.transformer(input_ids,
|
transformer_outputs = self.transformer(input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user