From 07bc8efbc30f88e25d78b66811d670584a1bb97b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 15 Nov 2019 10:51:38 +0100 Subject: [PATCH 01/11] add greedy decoding and sampling --- examples/run_generation.py | 292 ++++++++++------------- transformers/modeling_encoder_decoder.py | 162 ++++++++++--- transformers/modeling_transfo_xl.py | 10 +- transformers/modeling_utils.py | 229 ++++++++++++++++++ transformers/modeling_xlm.py | 29 ++- transformers/modeling_xlnet.py | 34 +++ transformers/tests/sampling_test.py | 213 +++++++++++++++++ 7 files changed, 776 insertions(+), 193 deletions(-) create mode 100644 transformers/tests/sampling_test.py diff --git a/examples/run_generation.py b/examples/run_generation.py index 2d917660cf..2075ad8457 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -20,14 +20,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera import argparse import logging -from tqdm import trange import torch -import torch.nn.functional as F import numpy as np -from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig - from transformers import GPT2LMHeadModel, GPT2Tokenizer from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer from transformers import XLNetLMHeadModel, XLNetTokenizer @@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer from transformers import XLMWithLMHeadModel, XLMTokenizer -logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt = '%m/%d/%Y %H:%M:%S', - level = logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) logger = logging.getLogger(__name__) MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop -ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ()) - MODEL_CLASSES = { - 'gpt2': (GPT2LMHeadModel, GPT2Tokenizer), - 'ctrl': (CTRLLMHeadModel, CTRLTokenizer), - 'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), - 'xlnet': (XLNetLMHeadModel, XLNetTokenizer), - 'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer), - 'xlm': (XLMWithLMHeadModel, XLMTokenizer), + "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), + "ctrl": (CTRLLMHeadModel, CTRLTokenizer), + "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), + "xlnet": (XLNetLMHeadModel, XLNetTokenizer), + "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), + "xlm": (XLMWithLMHeadModel, XLMTokenizer), } # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia @@ -75,81 +71,78 @@ def set_seed(args): if args.n_gpu > 0: torch.cuda.manual_seed_all(args.seed) - -def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): - """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (batch size x vocabulary size) - top_k > 0: keep only top k tokens with highest probability (top-k filtering). - top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - top_k = min(top_k, logits.size(-1)) # Safety check - if top_k > 0: - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value - - if top_p > 0.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) - logits[indices_to_remove] = filter_value - return logits +# +# Functions to prepare models' input +# -def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, - is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'): - context = torch.tensor(context, dtype=torch.long, device=device) - context = context.unsqueeze(0).repeat(num_samples, 1) - generated = context - with torch.no_grad(): - for _ in trange(length): +def prepare_ctrl_input(args, _, tokenizer, prompt_text): + if args.temperature > 0.7: + logger.info( + "CTRL typically works better with lower temperatures (and lower top_k)." + ) - inputs = {'input_ids': generated} - if is_xlnet: - # XLNet is a direct (predict same token, not next token) and bi-directional model by default - # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring) - input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1) - perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device) - perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token - target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device) - target_mapping[0, 0, -1] = 1.0 # predict last token - inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping} + encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) + if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): + logger.info( + "WARNING! You are not starting your generation from a control code so you won't get good results" + ) + return prompt_text, {} - if is_xlm_mlm and xlm_mask_token: - # XLM MLM models are direct models (predict same token, not next token) - # => need one additional dummy token in the input (will be masked and guessed) - input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1) - inputs = {'input_ids': input_ids} - if xlm_lang is not None: - inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1) +def prepare_xlm_input(args, model, tokenizer, prompt_text): + kwargs = {"language": None, "mask_token": None} - outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states) - next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.) + # Set the language + use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb + if hasattr(model.config, "lang2id") and use_lang_emb: + available_languages = model.config.lang2id.keys() + if args.xlm_language in available_languages: + language = args.xlm_language + else: + language = None + while language not in available_languages: + language = input( + "Using XLM. Select language in " + + str(list(available_languages)) + + " >>> " + ) + kwargs["language"] = tokenizer.lang2id[language] - # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) - for i in range(num_samples): - for _ in set(generated[i].tolist()): - next_token_logits[i, _] /= repetition_penalty - - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - if temperature == 0: # greedy sampling: - next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1) - else: - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) - generated = torch.cat((generated, next_token), dim=1) - return generated + # XLM masked-language modeling (MLM) models need masked token + is_xlm_mlm = "mlm" in args.model_name_or_path + if is_xlm_mlm: + kwargs["mask_token"] = tokenizer.mask_token_id + + return prompt_text, kwargs + + +def prepare_xlnet_input(args, _, tokenizer, prompt_text): + prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text + return prompt_text, {} + + +def prepare_transfoxl_input(args, _, tokenizer, prompt_text): + prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text + return prompt_text, {} + + +PREPROCESSING_FUNCTIONS = { + "ctrl": prepare_ctrl_input, + "xlm": prepare_xlm_input, + "xlnet": prepare_xlnet_input, + "transfo-xl": prepare_transfoxl_input, +} + + +def adjust_length_to_model(length, max_sequence_length): + if length < 0 and max_sequence_length > 0: + length = max_sequence_length + elif 0 < max_sequence_length < length: + length = max_sequence_length # No generation bigger than model size + elif length < 0: + length = MAX_LENGTH # avoid infinite loop + return length def main(): @@ -157,104 +150,81 @@ def main(): parser.add_argument("--model_type", default=None, type=str, required=True, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) parser.add_argument("--model_name_or_path", default=None, type=str, required=True, - help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--prompt", type=str, default="") - parser.add_argument("--padding_text", type=str, default="") - parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.") parser.add_argument("--length", type=int, default=20) - parser.add_argument("--num_samples", type=int, default=1) - parser.add_argument("--temperature", type=float, default=1.0, - help="temperature of 0 implies greedy sampling") - parser.add_argument("--repetition_penalty", type=float, default=1.0, - help="primarily useful for CTRL model; in that case, use 1.2") - parser.add_argument("--top_k", type=int, default=0) - parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--no_cuda", action='store_true', - help="Avoid using CUDA when available") - parser.add_argument('--seed', type=int, default=42, - help="random seed for initialization") - parser.add_argument('--stop_token', type=str, default=None, - help="Token at which text generation is stopped") + parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") + + parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling") + parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2") + parser.add_argument("--k", type=int, default=0) + parser.add_argument("--p", type=float, default=0.9) + + parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.") + parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") + + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") args = parser.parse_args() - args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.device = torch.device( + "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + ) args.n_gpu = torch.cuda.device_count() set_seed(args) - args.model_type = args.model_type.lower() - model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + # Initialize the model and tokenizer + try: + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + except KeyError as ke: + raise ke( + "the model {} you specified is not supported. You are welcome to add it and open a PR :)" + ) + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) model = model_class.from_pretrained(args.model_name_or_path) model.to(args.device) model.eval() - if args.length < 0 and model.config.max_position_embeddings > 0: - args.length = model.config.max_position_embeddings - elif 0 < model.config.max_position_embeddings < args.length: - args.length = model.config.max_position_embeddings # No generation bigger than model size - elif args.length < 0: - args.length = MAX_LENGTH # avoid infinite loop - + args.length = adjust_length_to_model( + args.length, max_sequence_length=model.config.max_position_embeddings + ) logger.info(args) - if args.model_type in ["ctrl"]: - if args.temperature > 0.7: - logger.info('CTRL typically works better with lower temperatures (and lower top_k).') - while True: - xlm_lang = None - # XLM Language usage detailed in the issues #1414 - if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \ - and model.config.use_lang_emb: - if args.xlm_lang: - language = args.xlm_lang - else: - language = None - while language not in tokenizer.lang2id.keys(): - language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ") - xlm_lang = tokenizer.lang2id[language] + prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") - # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence) - is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path - if is_xlm_mlm: - xlm_mask_token = tokenizer.mask_token_id - else: - xlm_mask_token = None + # Different models need different input formatting and/or extra arguments + requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() + model_kwargs = {} + if requires_preprocessing: + prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) + prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text) + encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0) - raw_text = args.prompt if args.prompt else input("Model prompt >>> ") - if args.model_type in ["transfo-xl", "xlnet"]: - # Models with memory likes to have a long prompt for short inputs. - raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text - context_tokens = tokenizer.encode(raw_text, add_special_tokens=False) - if args.model_type == "ctrl": - if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()): - logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") - out = sample_sequence( - model=model, - context=context_tokens, - num_samples=args.num_samples, - length=args.length, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - repetition_penalty=args.repetition_penalty, - is_xlnet=bool(args.model_type == "xlnet"), - is_xlm_mlm=is_xlm_mlm, - xlm_mask_token=xlm_mask_token, - xlm_lang=xlm_lang, - device=args.device, - ) - out = out[:, len(context_tokens):].tolist() - for o in out: - text = tokenizer.decode(o, clean_up_tokenization_spaces=True) - text = text[: text.find(args.stop_token) if args.stop_token else None] + output_sequences = model.decode( + prompt_ids=encoded_prompt, + length=args.length, + temperature=args.temperature, + k=args.k, + p=args.p, + repetition_penalty=args.repetition_penalty, + device=args.device, + **model_kwargs, + ) - print(text) + generated_sequence = output_sequences.tolist()[ + encoded_prompt.size(1) : + ] # adapted to case where num_samples > 1 + text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) + text = text[: text.find(args.stop_token) if args.stop_token else None] + + print(text) - if args.prompt: - break return text -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/transformers/modeling_encoder_decoder.py b/transformers/modeling_encoder_decoder.py index a884abd0a2..3d8c812c2f 100644 --- a/transformers/modeling_encoder_decoder.py +++ b/transformers/modeling_encoder_decoder.py @@ -18,11 +18,14 @@ from __future__ import absolute_import, division, print_function, unicode_litera import logging import os +import warnings import torch from torch import nn +from tqdm import trange from .modeling_auto import AutoModel, AutoModelWithLMHead +from .modeling_utils import Sampler logger = logging.getLogger(__name__) @@ -117,8 +120,7 @@ class PreTrainedEncoderDecoder(nn.Module): kwargs_common = { argument: value for argument, value in kwargs.items() - if not argument.startswith("encoder_") - and not argument.startswith("decoder_") + if not argument.startswith("encoder_") and not argument.startswith("decoder_") } kwargs_decoder = kwargs_common.copy() kwargs_encoder = kwargs_common.copy() @@ -186,51 +188,151 @@ class PreTrainedEncoderDecoder(nn.Module): Indices of decoder input sequence tokens in the vocabulary. kwargs: (`optional`) Remaining dictionary of keyword arguments. """ - # keyword arguments come in 3 flavors: encoder-specific (prefixed by - # `encoder_`), decoder-specific (prefixed by `decoder_`) and those - # that apply to the model as whole. - # We let the specific kwargs override the common ones in case of conflict. + kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs) + + # Encode if needed (training, first prediction pass) + encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) + if encoder_hidden_states is None: + encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) + encoder_hidden_states = encoder_outputs[0] + else: + encoder_outputs = () + + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states + decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder) + + return decoder_outputs + encoder_outputs + + def decode( + self, + encoder_input_ids, + decoder_prompt_ids=None, + device=torch.device("cpu"), + length=10, + do_sample=False, + temperature=1.0, + k=9, + p=0., + repetition_penalty=1., + **kwargs + ): + """ Generic sequence generator for encoder-decoder models. + + For encoder-decoders the generation consists in: + - Performing a forward pass through the encoder once; + - Pass the encoder's hidden states to a decoding mechanism that + repeatedly calls the decoder to generate sequences. + + The method currently supports greedy decoding and sampling. See the + documentation of the `Sampler` class for more information about the + parameters related to sampling. + + Params: + **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length) + The sequence to encode. + **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape (1,) + **device**: (`optional`) `torch.device` + The device on which the prompt_ids will be initialized if not provided. + **length**: (`optional`) int + The length of the sequence to be generated. + **do_sample**: (`optional`) bool + If set to `False` we use greedy decoding; otherwise sampling. + **temperature**: (`optional`) float + The value used to module the next token probabilities. + **k**: (`optional`) int + The parameter used for k-filtering. + **p**: (`optional`) float + The parameter for nucleus sampling. Must be between 0 and 1. + **repetition_penalty**: (`optional`) float + The parameter for repetition penalty. + """ + if decoder_prompt_ids is None: + decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device) + + # When the model does not have a LM head `get_output_embeddings` + # returns `None`. We use this mechanism to determine whether we + # should proceed with decoding or not. + if self.decoder.get_output_embeddings() is None: + raise AttributeError("You tried do generated sequences with a decoder that does not have a LM Head.") + + # The followings checks that the decoder is on the same device as the one + # that is specified. It only works for models that fit on one GPU. + decoder_device = next(self.decoder.parameters()).device + if decoder_device != decoder_prompt_ids.device: + warnings.warn( + "The decoder is not on the same device as the prompt. Expected {}, got {}.".format( + decoder_prompt_ids.device, decoder_device + ) + ) + + kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs) + with torch.no_grad(): + encoder_outputs = self.encoder(encoder_input_ids, **kwargs) + encoder_hidden_states = encoder_outputs[0] + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states + + sampler_config = { + "k": k, + "p": p, + "do_sample": do_sample, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + } + return self._greedy_decode_or_sample( + decoder_prompt_ids, length, sampler_config, **kwargs_decoder + ) + + def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder): + sampler = Sampler(**sampler_config) + with torch.no_grad(): + generated_sequence = prompt_ids + for _ in trange(length): + arguments = self.decoder._prepare_inputs_for_decoding(generated_sequence, **kwargs_decoder) + outputs = self.decoder(**arguments) + next_tokens_logits = outputs[0][:, -1, :] + next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence) + generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) + + return generated_sequence.squeeze(0) + + @staticmethod + def prepare_model_kwargs(**kwargs): + """ Prepare the encoder and decoder's keyword arguments. + + Keyword arguments come in 3 flavors: + - encoder-specific (prefixed by `encoder_`) + - decoder-specific (prefixed by `decoder_`) + - those that apply to the model as whole. + + We let the specific kwargs override the common ones in case of + conflict. + """ kwargs_common = { argument: value for argument, value in kwargs.items() - if not argument.startswith("encoder_") - and not argument.startswith("decoder_") + if not argument.startswith("encoder_") and not argument.startswith("decoder_") } - kwargs_decoder = kwargs_common.copy() - kwargs_encoder = kwargs_common.copy() - kwargs_encoder.update( + decoder_kwargs = kwargs_common.copy() + encoder_kwargs = kwargs_common.copy() + encoder_kwargs.update( { argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") } ) - kwargs_decoder.update( + decoder_kwargs.update( { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } ) + decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None) - # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) - if encoder_hidden_states is None: - encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[ - 0 - ] # output the last layer hidden state - else: - encoder_outputs = () - - # Decode - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get( - "attention_mask", None - ) - decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) - - return decoder_outputs + encoder_outputs + return encoder_kwargs, decoder_kwargs class Model2Model(PreTrainedEncoderDecoder): diff --git a/transformers/modeling_transfo_xl.py b/transformers/modeling_transfo_xl.py index a6a82f0dfe..473d07f733 100644 --- a/transformers/modeling_transfo_xl.py +++ b/transformers/modeling_transfo_xl.py @@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary from .configuration_transfo_xl import TransfoXLConfig -from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits +from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) @@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): outputs = [softmax_output, None] + outputs return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions) + + def get_output_embeddings(self): + """ Double-check if you are using adaptive softmax. + """ + if self.sample_softmax > 0: + return self.out_layer + else: + return self.crit.out_layers[-1] diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 398172a88c..74038351fd 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -23,12 +23,14 @@ import json import logging import os from io import open +import warnings import six import torch from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F +from tqdm import trange from .configuration_utils import PretrainedConfig from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME @@ -87,6 +89,93 @@ class PreTrainedModel(nn.Module): def base_model(self): return getattr(self, self.base_model_prefix, self) + def decode(self, + prompt_ids=None, + device=torch.device('cpu'), + length=10, + do_sample=False, + temperature=1., + k=9, + p=0, + repetition_penalty=1, + **model_kwargs): + """ Generic sequence generator for single-stack models with a LM head. + + The method currently supports greedy decoding and sampling. See the + documentation of the `Sampler` class for more information about the + parameters related to sampling. + + Params: + **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length) + The sequence to encode. + **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape (1,) + **device**: (`optional`) `torch.device` + The device on which the prompt_ids will be initialized if not provided. + **length**: (`optional`) int + The length of the sequence to be generated. + **do_sample**: (`optional`) bool + If set to `False` we use greedy decoding; otherwise sampling. + **temperature**: (`optional`) float + The value used to module the next token probabilities. + **k**: (`optional`) int + The parameter used for k-filtering. + **p**: (`optional`) float + The parameter for nucleus sampling. Must be between 0 and 1. + **repetition_penalty**: (`optional`) float + The parameter for repetition penalty. + """ + + if prompt_ids is None: + prompt_ids = torch.tensor([[]], dtype=torch.long, device=device) + + # When the model does not have a LM head `get_output_embeddings` + # returns `None`. We use this mechanism to determine whether we + # should proceed with decoding or not. + if self.get_output_embeddings() is None: + raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") + + # The followings checks that the model is on the same device as the one + # that is specified. It only works for models that fit on one GPU. + model_device = next(self.parameters()).device + if model_device != prompt_ids.device: + warnings.warn( + "The model is not on the same device as the prompts. Expected {}, got {}.".format( + prompt_ids.device, model_device + ) + ) + + sampler_config = { + "k": k, + "p": p, + "do_sample": do_sample, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + } + return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs) + + def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs): + """ Generate text using greedy decoding or by sampling tokens.""" + sampler = Sampler(**sampler_config) + generated_sequence = prompt_ids + with torch.no_grad(): + for _ in trange(length): + arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs) + outputs = self(**arguments) + next_tokens_logits = outputs[0][:, -1, :] + next_tokens = sampler.get_one_token( + next_tokens_logits, generated_sequence + ) + generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) + + return generated_sequence.squeeze(0) + + def _prepare_inputs_for_decoding(self, input_ids, **kwargs): + arguments = {"input_ids": input_ids} + arguments.update(kwargs) + return arguments + def get_input_embeddings(self): """ Get model's input embeddings """ @@ -859,3 +948,143 @@ def prune_layer(layer, index, dim=None): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__)) + + +class Sampler(object): + r""" Sampler is used to generate sequences of ids from logit inputs. + + Greedy decoding, which consists in chosing the most probable token at each + step, is the default behaviour. Sampling with varying temperature, top_k + and nucleus filtering is also implemented. + + Attributes: + **device**: ``torch.device`` + Device on which the computations will be run. + **do_sample**: bool + Whether to sample or do greedy decoding. + **k**: int between 0 and vocab_size + Parameter for the top-k filtering + **p**: float between 0 and 1 + Parameter for the nucleus filtering + **temperature**: strictly positive float + Parameter used to modulate the distribution over ids. Low temperatures + put more emphasis on highly probably token while high temperatures tend + to smooth the probability distribution. + **repetition_penalty**: strictly postitive float + The penalty applied to repeating ids + """ + + def __init__( + self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0 + ): + self.k = k + self.p = p + self.do_sample = do_sample + self.temperature = temperature + self.repetition_penalty = repetition_penalty + + self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False + + if self.p > 1: + warnings.warn( + """You are trying to apply nucleus filtering with a value of p greater than 1 ({}). + However p is a probability and its value must lie between 0 and 1. In effect, no filtering + will be applied. If this is not the behavior you expect, change the value of p.""".format( + self.p + ) + ) + + def get_one_token(self, next_token_logits, past_sequence): + logits = self.apply_repetition_penalty(next_token_logits, past_sequence) + if self.do_sample: + logits = self.apply_temperature(logits) + logits = self.apply_top_k_filter(logits) + logits = self.apply_nucleus_filter(logits) + return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return torch.argmax(logits, dim=-1).unsqueeze(-1) + + def apply_repetition_penalty(self, logits, past_sequence): + """ Apply a penalty to tokens that appear more than once in the + generated sequence. + + .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer + language model for controllable generation." arXiv preprint + arXiv:1909.05858 (2019). + """ + if self.do_apply_repetition_penalty: + generated_token_idx = set(past_sequence[0].tolist()) + for token_idx in generated_token_idx: + logits[0, token_idx] /= self.repetition_penalty + return logits + + def apply_temperature(self, logits): + """ Shape the tokens' distribution through temperature. The higher the value + of the temperature, the more skewed towards high probability events the + distribution is. + + .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. + MIT press, 2016. + """ + # when dividing a float by 0, torch returns inf which in turns breaks the + # multinomial with an error message that is not very helpful. It is better + # for the user to break the execution and explain why. + if self.temperature == 0: + raise ZeroDivisionError( + """You are trying to sample with a temperature equal to 0. + If you wanted to do greedy sampling, set instead `do_sample` to False. + Otherwise set the temperature to a value different from 0.""" + ) + return logits / self.temperature + + def apply_top_k_filter(self, logits): + """ Use the probability distribution of the tokens to determine the set + to be sampled from. Specifically we select the set of size k such that + the sum of its items' probabilities is maximum. + + .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural + story generation." arXiv preprint arXiv:1805.04833 (2018). + """ + if self.k > 0: + vocabulary_size = logits.size(-1) + if self.k > vocabulary_size: + warnings.warn( + """You provided a value for k ({}) that is larger than the vocabulary size ({}). + We adjusted k's value to the vocabulary size; if that was what you intended to do + we recommend setting k to 0 instead. It this is not the behavior you expected, + choose a value of k that is smaller than the vocabulary size.""".format( + self.k, vocabulary_size + ) + ) + self.k = vocabulary_size + + indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None] + logits[indices_to_remove] = -float("Inf") + + return logits + + def apply_nucleus_filter(self, logits): + """ Use the probability distribution of the tokens to determine the set + to be sampled from. Specifically, choose the smallest set such that the + sum of its items' probabilities is greater than a number p in [0,1]. + + .. Holtzman, Ari, et al. "The curious case of neural text + degeneration." arXiv preprint arXiv:1904.09751 (2019). + """ + if self.p > 0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + sorted_probabilities = F.softmax(sorted_logits, dim=-1) + cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1) + + # Remove tokens with cumulative probability above the threshold, + # but keep the first token above the threshold. + sorted_indices_to_remove = cumulative_probabilities > self.p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits[indices_to_remove] = -float("Inf") + + return logits diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py index 257f0da394..295fff7943 100644 --- a/transformers/modeling_xlm.py +++ b/transformers/modeling_xlm.py @@ -646,7 +646,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): langs=langs, token_type_ids=token_type_ids, position_ids=position_ids, - lengths=lengths, + lengths=lengths, cache=cache, head_mask=head_mask, inputs_embeds=inputs_embeds) @@ -657,6 +657,33 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): return outputs + def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs): + mask_token = model_kwargs.pop("mask_token", None) + language = model_kwargs.pop("language", None) + input_ids = self._append_mask_token(input_ids, mask_token) + langs = self._create_language_embeddings(input_ids, language) + arguments = {"input_ids": input_ids, "langs": langs} + arguments.update(model_kwargs) + + return arguments + + @staticmethod + def _append_mask_token(sequence, mask_token_id): + """ Append a [MASK] token at the end of the sequence that the MLM model + is going to try to predict. + """ + if mask_token_id is not None: + tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long) + return torch.cat((sequence, tokens_to_append), dim=1) + + return sequence + + @staticmethod + def _create_language_embeddings(sequence, language): + if language is not None: + return torch.tensor([language] * sequence.shape[1]).view(1, -1) + return None + @add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py index 225e5b059b..2153923dd2 100644 --- a/transformers/modeling_xlnet.py +++ b/transformers/modeling_xlnet.py @@ -972,6 +972,40 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): return outputs # return (loss), logits, (mems), (hidden states), (attentions) + def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs): + input_ids = self._add_dummy_token(input_ids) + perm_mask = self._create_perm_mask(input_ids) + target_mapping = self._create_target_mapping(input_ids) + arguments = { + "input_ids": input_ids, + "perm_mask": perm_mask, + "target_mapping": target_mapping, + } + return arguments + + @staticmethod + def _add_dummy_token(sequence): + dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long) + return torch.cat((sequence, dummy), dim=1) + + @staticmethod + def _create_perm_mask(sequence): + mask = torch.zeros( + (sequence.shape[0], sequence.shape[1], sequence.shape[1]), + dtype=torch.float, + ) + mask[:, :, -1] = 1.0 # Previous tokens don't see last token + return mask + + @staticmethod + def _create_target_mapping(sequence): + target_mapping = torch.zeros( + (sequence.shape[0], 1, sequence.shape[1]), + dtype=torch.float, + ) + target_mapping[0, 0, -1] = 1.0 # predict last token + return target_mapping + @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, diff --git a/transformers/tests/sampling_test.py b/transformers/tests/sampling_test.py new file mode 100644 index 0000000000..98cc23bf2b --- /dev/null +++ b/transformers/tests/sampling_test.py @@ -0,0 +1,213 @@ +# coding=utf-8 +import sys +import unittest + +import numpy as np +import pytest + +from transformers import is_torch_available + +if is_torch_available(): + import torch + + from transformers import ( + BertConfig, + BertModel, + GPT2Config, + GPT2LMHeadModel, + OpenAIGPTConfig, + OpenAIGPTLMHeadModel, + TransfoXLConfig, + TransfoXLLMHeadModel, + XLMConfig, + XLMWithLMHeadModel, + XLNetConfig, + XLNetLMHeadModel, + Model2Model, + ) + from transformers.modeling_utils import Sampler +else: + pytestmark = pytest.mark.skip("Require Torch") + + +class SamplerTest(unittest.TestCase): + def test_nucleus_sampling(self): + inf = -float("Inf") + test_cases = ( + { + "p": 0, + "logits": torch.tensor([0.3, 0.1, 0.2]), + "expected": torch.tensor([0.3, 0.1, 0.2]), + }, + { + "p": 0.01, + "logits": torch.tensor([0.3, 0.1, 0.2]), + "expected": torch.tensor([0.3, inf, inf]), + }, + { + "p": 1, + "logits": torch.tensor([0.3, 0.1, 0.2]), + "expected": torch.tensor([0.3, 0.1, 0.2]), + }, + { + "p": 0.2, + "logits": torch.tensor([0.7, 0.1, 0.2]), + "expected": torch.tensor([0.7, inf, inf]), + }, + { + "p": 0.71, + "logits": torch.tensor([0.7, 0.1, 0.2]), + "expected": torch.tensor([0.7, inf, 0.2]), + }, + { + "p": 0.71, + "logits": torch.tensor([0.1, 0.7, 0.2]), + "expected": torch.tensor([inf, 0.7, 0.2]), + }, + { + "p": 0.71, + "logits": torch.tensor([0.7, 0.2, 0.1]), + "expected": torch.tensor([0.7, 0.2, inf]), + }, + { + "p": 0.91, + "logits": torch.tensor([0.7, 0.1, 0.2]), + "expected": torch.tensor([0.7, 0.1, 0.2]), + }, + ) + for case in test_cases: + config = { + "do_sample": True, + "temperature": 1.0, + "k": 0, + "p": case["p"], + "repetition_penalty": 1.0, + } + sampler = Sampler(**config) + filtered_logits = sampler.apply_nucleus_filter(case["logits"]) + np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy()) + + def test_top_k_filter(self): + inf = -float("Inf") + test_cases = ( + { + "k": 0, + "logits": torch.tensor([0.7, 0.1, 0.2]), + "expected": torch.tensor([0.7, 0.1, 0.2]), + }, + { + "k": 1, + "logits": torch.tensor([0.7, 0.1, 0.2]), + "expected": torch.tensor([0.7, inf, inf]), + }, + { + "k": 2, + "logits": torch.tensor([0.7, 0.1, 0.2]), + "expected": torch.tensor([0.7, inf, 0.2]), + }, + { + "k": 3, + "logits": torch.tensor([0.7, 0.1, 0.2]), + "expected": torch.tensor([0.7, 0.1, 0.2]), + }, + ) + for case in test_cases: + config = { + "do_sample": True, + "temperature": 1.0, + "k": case["k"], + "p": 0, + "repetition_penalty": 1.0, + } + sampler = Sampler(**config) + filtered_logits = sampler.apply_top_k_filter(case["logits"]) + np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy()) + + @pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2") + def test_wrong_k_value(self): + case = {"k": 10, "vocab_size": 5} + config = { + "do_sample": True, + "temperature": 1.0, + "k": case["k"], + "p": 0, + "repetition_penalty": 1.0, + } + sampler = Sampler(**config) + next_token_logits = torch.rand(case["vocab_size"]).unsqueeze(0) + past_sequence = torch.tensor([]) + with self.assertWarns(UserWarning): + _ = sampler.get_one_token(next_token_logits, past_sequence) + + def test_zero_temperature(self): + temperature = 0 + config = { + "do_sample": True, + "temperature": temperature, + "k": 0, + "p": 0, + "repetition_penalty": 1.0, + } + sampler = Sampler(**config) + next_token_logits = torch.rand(10).unsqueeze(0) + past_sequence = torch.tensor([]) + with self.assertRaises(ZeroDivisionError): + _ = sampler.get_one_token(next_token_logits, past_sequence) + + +class SamplerSingleStackTest(unittest.TestCase): + def test_raises_exception_when_no_LM_head(self): + models = [BertModel(BertConfig())] + for model in models: + with self.assertRaises(AttributeError): + model.decode() + + @pytest.mark.slow + def test_forward_pass_and_output_length(self): + models = { + "XLNet": XLNetLMHeadModel(XLNetConfig()), + "XLM": XLMWithLMHeadModel(XLMConfig()), + "TransfoXL": TransfoXLLMHeadModel(TransfoXLConfig()), + "GPT2": GPT2LMHeadModel(GPT2Config()), + "GPT": OpenAIGPTLMHeadModel(OpenAIGPTConfig()), + } + kwargs = { + "XLNet": {}, + "XLM": {"mask_token": 0}, + "TransfoXL": {}, + "GPT2": {}, + "GPT": {}, + } + prompt = torch.tensor([[1, 2, 3]], dtype=torch.long) + generated_length = 5 + expected_length = 8 + + for name, model in models.items(): + kwargs_model = kwargs[name] + output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs_model) + self.assertEqual(len(output), expected_length) + + +class SamplerEncoderDecoderTest(unittest.TestCase): + @pytest.mark.slow + def test_forward_pass_and_output_length(self): + model = Model2Model.from_pretrained("bert-base-uncased") + + encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + prompt = torch.tensor([[1, 2, 3]], dtype=torch.long) + generated_length = 5 + expected_length = 8 + + output = model.decode( + encoder_input_ids, + decoder_prompt_ids=prompt, + k=2, + p=0.5, + repetition_penalty=2, + length=generated_length, + ) + self.assertEqual(len(output), expected_length) + + +if __name__ == "__main__": + unittest.main() From a468870fd27c601e3717c5b9ca691e18a8c7227f Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 16 Dec 2019 22:22:30 +0100 Subject: [PATCH 02/11] refactoring generation --- transformers/configuration_utils.py | 11 + transformers/modeling_utils.py | 429 +++++++++++++--------------- 2 files changed, 213 insertions(+), 227 deletions(-) diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 08cee75d81..9c3360892d 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -57,8 +57,19 @@ class PretrainedConfig(object): self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models self.use_bfloat16 = kwargs.pop('use_bfloat16', False) self.pruned_heads = kwargs.pop('pruned_heads', {}) + + # Is decoder is used in encoder-decoder models to differentiate encoder from decoder self.is_decoder = kwargs.pop('is_decoder', False) + # Parameters for sequence generation + self.generate_length = kwargs.pop('generate_length', 10) + self.generate_do_sample = kwargs.pop('generate_do_sample', False) + self.generate_num_beams = kwargs.pop('generate_num_beams', 1) + self.generate_temperature = kwargs.pop('generate_temperature', 1.0) + self.generate_top_k = kwargs.pop('generate_top_k', 50) + self.generate_top_p = kwargs.pop('generate_top_p', 0.0) + self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0) + def save_pretrained(self, save_directory): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 74038351fd..27d42c552a 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -82,6 +82,7 @@ class PreTrainedModel(nn.Module): "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ )) + # Save config in model self.config = config @@ -89,93 +90,6 @@ class PreTrainedModel(nn.Module): def base_model(self): return getattr(self, self.base_model_prefix, self) - def decode(self, - prompt_ids=None, - device=torch.device('cpu'), - length=10, - do_sample=False, - temperature=1., - k=9, - p=0, - repetition_penalty=1, - **model_kwargs): - """ Generic sequence generator for single-stack models with a LM head. - - The method currently supports greedy decoding and sampling. See the - documentation of the `Sampler` class for more information about the - parameters related to sampling. - - Params: - **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length) - The sequence to encode. - **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) - The sequence used as a prompt for the generation. If `None` the method initializes - it as an empty `torch.LongTensor` of shape (1,) - **device**: (`optional`) `torch.device` - The device on which the prompt_ids will be initialized if not provided. - **length**: (`optional`) int - The length of the sequence to be generated. - **do_sample**: (`optional`) bool - If set to `False` we use greedy decoding; otherwise sampling. - **temperature**: (`optional`) float - The value used to module the next token probabilities. - **k**: (`optional`) int - The parameter used for k-filtering. - **p**: (`optional`) float - The parameter for nucleus sampling. Must be between 0 and 1. - **repetition_penalty**: (`optional`) float - The parameter for repetition penalty. - """ - - if prompt_ids is None: - prompt_ids = torch.tensor([[]], dtype=torch.long, device=device) - - # When the model does not have a LM head `get_output_embeddings` - # returns `None`. We use this mechanism to determine whether we - # should proceed with decoding or not. - if self.get_output_embeddings() is None: - raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") - - # The followings checks that the model is on the same device as the one - # that is specified. It only works for models that fit on one GPU. - model_device = next(self.parameters()).device - if model_device != prompt_ids.device: - warnings.warn( - "The model is not on the same device as the prompts. Expected {}, got {}.".format( - prompt_ids.device, model_device - ) - ) - - sampler_config = { - "k": k, - "p": p, - "do_sample": do_sample, - "temperature": temperature, - "repetition_penalty": repetition_penalty, - } - return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs) - - def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs): - """ Generate text using greedy decoding or by sampling tokens.""" - sampler = Sampler(**sampler_config) - generated_sequence = prompt_ids - with torch.no_grad(): - for _ in trange(length): - arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs) - outputs = self(**arguments) - next_tokens_logits = outputs[0][:, -1, :] - next_tokens = sampler.get_one_token( - next_tokens_logits, generated_sequence - ) - generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) - - return generated_sequence.squeeze(0) - - def _prepare_inputs_for_decoding(self, input_ids, **kwargs): - arguments = {"input_ids": input_ids} - arguments.update(kwargs) - return arguments - def get_input_embeddings(self): """ Get model's input embeddings """ @@ -306,6 +220,9 @@ class PreTrainedModel(nn.Module): # Tie weights if needed self.tie_weights() + # Initialize decoding head if we have output embeddings + + def prune_heads(self, heads_to_prune): """ Prunes heads of the base model. @@ -571,6 +488,204 @@ class PreTrainedModel(nn.Module): return model + def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None, + temperature=None, top_k=None, top_p=None, repetition_penalty=None, + **model_kwargs): + """ Generic sequence generator for single-stack models with a LM head. + + The method currently supports greedy decoding and sampling. See the + documentation of the `Sampler` class for more information about the + parameters related to sampling. + + Params: + **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape (1,) + **length**: (`optional`) int + The length of the sequence to be generated. + **do_sample**: (`optional`) bool + If set to `False` we use greedy decoding; otherwise sampling. + **temperature**: (`optional`) float + The value used to module the next token probabilities. + **k**: (`optional`) int + The parameter used for k-filtering. + **p**: (`optional`) float + The parameter for nucleus sampling. Must be between 0 and 1. + **repetition_penalty**: (`optional`) float + The parameter for repetition penalty. + """ + + if input_ids is None: + input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device) + + # We cannot generate if the model does not have a LM head + if self.get_output_embeddings() is None: + raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") + + sampler_config = { + "k": k, + "p": p, + "do_sample": do_sample, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + } + + sampler = Sampler(**sampler_config) + generated_sequence = input_ids + for _ in trange(length): + arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs) + outputs = self(**arguments) + next_tokens_logits = outputs[0][:, -1, :] + next_tokens = sampler.get_one_token( + next_tokens_logits, generated_sequence + ) + generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) + + return generated_sequence.squeeze(0) + + def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs): + return model_kwargs.update({"input_ids": input_ids}) + + +class Sampler(object): + r""" Sampler is used to generate sequences of ids from logit inputs. + + Greedy decoding, which consists in chosing the most probable token at each + step, is the default behaviour. Sampling with varying temperature, top_k + and nucleus filtering is also implemented. + + Attributes: + **device**: ``torch.device`` + Device on which the computations will be run. + **do_sample**: bool + Whether to sample or do greedy decoding. + **k**: int between 0 and vocab_size + Parameter for the top-k filtering + **p**: float between 0 and 1 + Parameter for the nucleus filtering + **temperature**: strictly positive float + Parameter used to modulate the distribution over ids. Low temperatures + put more emphasis on highly probably token while high temperatures tend + to smooth the probability distribution. + **repetition_penalty**: strictly postitive float + The penalty applied to repeating ids + """ + + def __init__( + self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0 + ): + self.k = k + self.p = p + self.do_sample = do_sample + self.temperature = temperature + self.repetition_penalty = repetition_penalty + + self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False + + if self.p > 1: + warnings.warn( + """You are trying to apply nucleus filtering with a value of p greater than 1 ({}). + However p is a probability and its value must lie between 0 and 1. In effect, no filtering + will be applied. If this is not the behavior you expect, change the value of p.""".format( + self.p + ) + ) + + def get_one_token(self, next_token_logits, past_sequence): + logits = self.apply_repetition_penalty(next_token_logits, past_sequence) + if self.do_sample: + logits = self.apply_temperature(logits) + logits = self.apply_top_k_filter(logits) + logits = self.apply_nucleus_filter(logits) + return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return torch.argmax(logits, dim=-1).unsqueeze(-1) + + def apply_repetition_penalty(self, logits, past_sequence): + """ Apply a penalty to tokens that appear more than once in the + generated sequence. + + .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer + language model for controllable generation." arXiv preprint + arXiv:1909.05858 (2019). + """ + if self.do_apply_repetition_penalty: + generated_token_idx = set(past_sequence[0].tolist()) + for token_idx in generated_token_idx: + logits[0, token_idx] /= self.repetition_penalty + return logits + + def apply_temperature(self, logits): + """ Shape the tokens' distribution through temperature. The higher the value + of the temperature, the more skewed towards high probability events the + distribution is. + + .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. + MIT press, 2016. + """ + # when dividing a float by 0, torch returns inf which in turns breaks the + # multinomial with an error message that is not very helpful. It is better + # for the user to break the execution and explain why. + if self.temperature == 0: + raise ZeroDivisionError( + """You are trying to sample with a temperature equal to 0. + If you wanted to do greedy sampling, set instead `do_sample` to False. + Otherwise set the temperature to a value different from 0.""" + ) + return logits / self.temperature + + def apply_top_k_filter(self, logits): + """ Use the probability distribution of the tokens to determine the set + to be sampled from. Specifically we select the set of size k such that + the sum of its items' probabilities is maximum. + + .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural + story generation." arXiv preprint arXiv:1805.04833 (2018). + """ + if self.k > 0: + vocabulary_size = logits.size(-1) + if self.k > vocabulary_size: + warnings.warn( + """You provided a value for k ({}) that is larger than the vocabulary size ({}). + We adjusted k's value to the vocabulary size; if that was what you intended to do + we recommend setting k to 0 instead. It this is not the behavior you expected, + choose a value of k that is smaller than the vocabulary size.""".format( + self.k, vocabulary_size + ) + ) + self.k = vocabulary_size + + indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None] + logits[indices_to_remove] = -float("Inf") + + return logits + + def apply_nucleus_filter(self, logits): + """ Use the probability distribution of the tokens to determine the set + to be sampled from. Specifically, choose the smallest set such that the + sum of its items' probabilities is greater than a number p in [0,1]. + + .. Holtzman, Ari, et al. "The curious case of neural text + degeneration." arXiv preprint arXiv:1904.09751 (2019). + """ + if self.p > 0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + sorted_probabilities = F.softmax(sorted_logits, dim=-1) + cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1) + + # Remove tokens with cumulative probability above the threshold, + # but keep the first token above the threshold. + sorted_indices_to_remove = cumulative_probabilities > self.p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits[indices_to_remove] = -float("Inf") + + return logits + class Conv1D(nn.Module): def __init__(self, nf, nx): @@ -948,143 +1063,3 @@ def prune_layer(layer, index, dim=None): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__)) - - -class Sampler(object): - r""" Sampler is used to generate sequences of ids from logit inputs. - - Greedy decoding, which consists in chosing the most probable token at each - step, is the default behaviour. Sampling with varying temperature, top_k - and nucleus filtering is also implemented. - - Attributes: - **device**: ``torch.device`` - Device on which the computations will be run. - **do_sample**: bool - Whether to sample or do greedy decoding. - **k**: int between 0 and vocab_size - Parameter for the top-k filtering - **p**: float between 0 and 1 - Parameter for the nucleus filtering - **temperature**: strictly positive float - Parameter used to modulate the distribution over ids. Low temperatures - put more emphasis on highly probably token while high temperatures tend - to smooth the probability distribution. - **repetition_penalty**: strictly postitive float - The penalty applied to repeating ids - """ - - def __init__( - self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0 - ): - self.k = k - self.p = p - self.do_sample = do_sample - self.temperature = temperature - self.repetition_penalty = repetition_penalty - - self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False - - if self.p > 1: - warnings.warn( - """You are trying to apply nucleus filtering with a value of p greater than 1 ({}). - However p is a probability and its value must lie between 0 and 1. In effect, no filtering - will be applied. If this is not the behavior you expect, change the value of p.""".format( - self.p - ) - ) - - def get_one_token(self, next_token_logits, past_sequence): - logits = self.apply_repetition_penalty(next_token_logits, past_sequence) - if self.do_sample: - logits = self.apply_temperature(logits) - logits = self.apply_top_k_filter(logits) - logits = self.apply_nucleus_filter(logits) - return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) - return torch.argmax(logits, dim=-1).unsqueeze(-1) - - def apply_repetition_penalty(self, logits, past_sequence): - """ Apply a penalty to tokens that appear more than once in the - generated sequence. - - .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer - language model for controllable generation." arXiv preprint - arXiv:1909.05858 (2019). - """ - if self.do_apply_repetition_penalty: - generated_token_idx = set(past_sequence[0].tolist()) - for token_idx in generated_token_idx: - logits[0, token_idx] /= self.repetition_penalty - return logits - - def apply_temperature(self, logits): - """ Shape the tokens' distribution through temperature. The higher the value - of the temperature, the more skewed towards high probability events the - distribution is. - - .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. - MIT press, 2016. - """ - # when dividing a float by 0, torch returns inf which in turns breaks the - # multinomial with an error message that is not very helpful. It is better - # for the user to break the execution and explain why. - if self.temperature == 0: - raise ZeroDivisionError( - """You are trying to sample with a temperature equal to 0. - If you wanted to do greedy sampling, set instead `do_sample` to False. - Otherwise set the temperature to a value different from 0.""" - ) - return logits / self.temperature - - def apply_top_k_filter(self, logits): - """ Use the probability distribution of the tokens to determine the set - to be sampled from. Specifically we select the set of size k such that - the sum of its items' probabilities is maximum. - - .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural - story generation." arXiv preprint arXiv:1805.04833 (2018). - """ - if self.k > 0: - vocabulary_size = logits.size(-1) - if self.k > vocabulary_size: - warnings.warn( - """You provided a value for k ({}) that is larger than the vocabulary size ({}). - We adjusted k's value to the vocabulary size; if that was what you intended to do - we recommend setting k to 0 instead. It this is not the behavior you expected, - choose a value of k that is smaller than the vocabulary size.""".format( - self.k, vocabulary_size - ) - ) - self.k = vocabulary_size - - indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None] - logits[indices_to_remove] = -float("Inf") - - return logits - - def apply_nucleus_filter(self, logits): - """ Use the probability distribution of the tokens to determine the set - to be sampled from. Specifically, choose the smallest set such that the - sum of its items' probabilities is greater than a number p in [0,1]. - - .. Holtzman, Ari, et al. "The curious case of neural text - degeneration." arXiv preprint arXiv:1904.09751 (2019). - """ - if self.p > 0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - sorted_probabilities = F.softmax(sorted_logits, dim=-1) - cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1) - - # Remove tokens with cumulative probability above the threshold, - # but keep the first token above the threshold. - sorted_indices_to_remove = cumulative_probabilities > self.p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter( - dim=-1, index=sorted_indices, src=sorted_indices_to_remove - ) - logits[indices_to_remove] = -float("Inf") - - return logits From b6938916ac7f00cd260e70d54b252909c40bced6 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 17 Dec 2019 17:23:36 +0100 Subject: [PATCH 03/11] adding beam search --- transformers/configuration_utils.py | 9 +- transformers/modeling_utils.py | 267 ++++++++++++++++++++++++---- 2 files changed, 235 insertions(+), 41 deletions(-) diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 9c3360892d..8c3e0a9f9c 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -62,13 +62,18 @@ class PretrainedConfig(object): self.is_decoder = kwargs.pop('is_decoder', False) # Parameters for sequence generation - self.generate_length = kwargs.pop('generate_length', 10) + self.generate_max_length = kwargs.pop('generate_max_length', 20) self.generate_do_sample = kwargs.pop('generate_do_sample', False) self.generate_num_beams = kwargs.pop('generate_num_beams', 1) self.generate_temperature = kwargs.pop('generate_temperature', 1.0) self.generate_top_k = kwargs.pop('generate_top_k', 50) - self.generate_top_p = kwargs.pop('generate_top_p', 0.0) + self.generate_top_p = kwargs.pop('generate_top_p', 1.0) self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0) + self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0) + self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0) + self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0) + self.generate_batch_size = kwargs.pop('generate_batch_size', 1) + self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.) def save_pretrained(self, save_directory): """ Save a configuration object to the directory `save_directory`, so that it diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 27d42c552a..003e17a0d9 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -488,63 +488,252 @@ class PreTrainedModel(nn.Module): return model - def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None, - temperature=None, top_k=None, top_p=None, repetition_penalty=None, - **model_kwargs): - """ Generic sequence generator for single-stack models with a LM head. + def prepare_inputs_for_generation(self, input_ids, **kwargs): + return {"input_ids": input_ids} - The method currently supports greedy decoding and sampling. See the - documentation of the `Sampler` class for more information about the - parameters related to sampling. + def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, + temperature=None, top_k=None, top_p=None, repetition_penalty=None, + bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None, + length_penalty=None, **kwargs): + """ Sequence generator for models with a LM head. + + The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling + and beam-search. + + Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM Params: **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) The sequence used as a prompt for the generation. If `None` the method initializes it as an empty `torch.LongTensor` of shape (1,) - **length**: (`optional`) int - The length of the sequence to be generated. + **max_length**: (`optional`) int + The max length of the sequence to be generated. Between 1 and infinity. Default to 20. **do_sample**: (`optional`) bool - If set to `False` we use greedy decoding; otherwise sampling. + If set to `False` we use greedy decoding; otherwise sampling. Default to greedy sampling. + **num_beams**: (`optional`) int + Number of beams for beam search. 1 means no beam serach. Default to 1. **temperature**: (`optional`) float The value used to module the next token probabilities. - **k**: (`optional`) int - The parameter used for k-filtering. - **p**: (`optional`) float - The parameter for nucleus sampling. Must be between 0 and 1. + **top_k**: (`optional`) int + The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + **top_p**: (`optional`) float + The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. **repetition_penalty**: (`optional`) float - The parameter for repetition penalty. + The parameter for repetition penalty. Between 1.0 and + infinity. 1.0 means no penalty. Default to 1. """ - if input_ids is None: - input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device) - # We cannot generate if the model does not have a LM head if self.get_output_embeddings() is None: raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") - sampler_config = { - "k": k, - "p": p, - "do_sample": do_sample, - "temperature": temperature, - "repetition_penalty": repetition_penalty, - } + max_length = max_length if max_length is not None else self.config.generate_max_length + do_sample = do_sample if do_sample is not None else self.config.generate_do_sample + num_beams = num_beams if num_beams is not None else self.config.generate_num_beams + temperature = temperature if temperature is not None else self.config.generate_temperature + top_k = top_k if top_k is not None else self.config.generate_top_k + top_p = top_p if top_p is not None else self.config.generate_top_p + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty + bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id + eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids + batch_size = batch_size if batch_size is not None else self.config.generate_batch_size + length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty - sampler = Sampler(**sampler_config) - generated_sequence = input_ids - for _ in trange(length): - arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs) - outputs = self(**arguments) - next_tokens_logits = outputs[0][:, -1, :] - next_tokens = sampler.get_one_token( - next_tokens_logits, generated_sequence - ) - generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) + if input_ids is not None: + batch_size = input_ids.shape[0] # overriden by the input batch_size + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] - return generated_sequence.squeeze(0) + assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictely positive integer." + assert isinstance(do_sample, bool), "`do_sample` should be a boolean." + assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictely positive integer." + assert 0 < temperature, "`temperature` should be positive." + assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictely positive integer." + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." + assert 0 < repetition_penalty, "`repetition_penalty` should be strictely positive." + assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer." + assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer." + assert isinstance(eos_token_ids, (list, tuple)) and (0 <= e for e in eos_token_ids), \ + "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." + assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictely positive integer." + assert 0 < length_penalty, "`length_penalty` should be strictely positive." - def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs): - return model_kwargs.update({"input_ids": input_ids}) + if input_ids is None: + input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device) + else: + assert input_ids.dims() == 2 + + # current position and vocab size + cur_len = 1 + vocab_size = self.config.vocab_size + + # Expand input to num beams + input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) + input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) + + # generated hypotheses + generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)] + + # scores for each sentence in the beam + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view(-1) + + # cache compute states + pasts = None # self.prepare_pasts() + + # done sentences + done = [False for _ in range(batch_size)] + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) + scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) + scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) + scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) + assert scores.size() == (batch_size * num_beams, vocab_size) + + # select next words with scores + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size) + + next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) + assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams) + + # next batch beam content + # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch) + next_batch_beam = [] + + # for each sentence + for sent_id in range(batch_size): + + # if we are done with this sentence + done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item()) + if done[sent_id]: + next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch + continue + + # next sentence beam content + next_sent_beam = [] + + # next words for this sentence + for idx, value in zip(next_words[sent_id], next_scores[sent_id]): + + # get beam and word IDs + beam_id = idx // vocab_size + word_id = idx % vocab_size + + # end of sentence, or next word + if word_id.item() in eos_token_ids or cur_len + 1 == max_length: + generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item()) + else: + next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id)) + + # the beam for next step is full + if len(next_sent_beam) == num_beams: + break + + # update next beam content + assert len(next_sent_beam) == 0 if cur_len + 1 == max_length else num_beams + if len(next_sent_beam) == 0: + next_sent_beam = [(0, pad_token_id, 0)] * num_beams # pad the batch + next_batch_beam.extend(next_sent_beam) + assert len(next_batch_beam) == num_beams * (sent_id + 1) + + # sanity check / prepare next batch + assert len(next_batch_beam) == batch_size * num_beams + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_words = input_ids.new([x[1] for x in next_batch_beam]) + beam_idx = input_ids.new([x[2] for x in next_batch_beam]) + + # re-order batch and internal states + input_ids = input_ids[beam_idx, :] + input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1) + # TODO: Activate cache + # for k in cache.keys(): + # if k != 'slen': + # cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx]) + + # update current length + cur_len = cur_len + 1 + + # stop when we are done with each sentence + if all(done): + break + + # visualize hypotheses + # print([len(x) for x in generated_hyps], cur_len) + # globals().update( locals() ); + # !import code; code.interact(local=vars()) + # for ii in range(batch_size): + # for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True): + # print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist())) + # print("") + + # select the best hypotheses + tgt_len = src_len.new(batch_size) + best = [] + + for i, hypotheses in enumerate(generated_hyps): + best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] + tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol + best.append(best_hyp) + + # generate target batch + decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index) + for i, hypo in enumerate(best): + decoded[:tgt_len[i] - 1, i] = hypo + decoded[tgt_len[i] - 1, i] = self.eos_index + + # sanity check + assert (decoded == self.eos_index).sum() == 2 * batch_size + + return decoded, tgt_len + + +class BeamHypotheses(object): + + def __init__(self, n_hyp, max_length, length_penalty, early_stopping): + """ + Initialize n-best list of hypotheses. + """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.n_hyp = n_hyp + self.hyp = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.hyp) + + def add(self, hyp, sum_logprobs): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / len(hyp) ** self.length_penalty + if len(self) < self.n_hyp or score > self.worst_score: + self.hyp.append((score, hyp)) + if len(self) > self.n_hyp: + sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)]) + del self.hyp[sorted_scores[0][1]] + self.worst_score = sorted_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs): + """ + If there are enough hypotheses and that none of the hypotheses being generated + can become better than the worst one in the heap, then we are done with this sentence. + """ + if len(self) < self.n_hyp: + return False + elif self.early_stopping: + return True + else: + return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty class Sampler(object): From bbc0c86f9b96b62b95853a18945f855c661a13b9 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 17 Dec 2019 23:27:02 +0100 Subject: [PATCH 04/11] beam search + single beam decoding --- transformers/modeling_utils.py | 152 ++++++++++++++++++++++++++------- 1 file changed, 123 insertions(+), 29 deletions(-) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 003e17a0d9..52743d8c2f 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -544,29 +544,90 @@ class PreTrainedModel(nn.Module): if isinstance(eos_token_ids, int): eos_token_ids = [eos_token_ids] - assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictely positive integer." + assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." - assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictely positive integer." - assert 0 < temperature, "`temperature` should be positive." - assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictely positive integer." + assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer." + assert temperature > 0, "`temperature` should be positive." + assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictely positive integer." assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." - assert 0 < repetition_penalty, "`repetition_penalty` should be strictely positive." - assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer." - assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer." - assert isinstance(eos_token_ids, (list, tuple)) and (0 <= e for e in eos_token_ids), \ + assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." + assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer." + assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer." + assert isinstance(eos_token_ids, (list, tuple)) and (e >= 0 for e in eos_token_ids), \ "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." - assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictely positive integer." - assert 0 < length_penalty, "`length_penalty` should be strictely positive." + assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictely positive integer." + assert length_penalty > 0, "`length_penalty` should be strictely positive." if input_ids is None: input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device) else: - assert input_ids.dims() == 2 + assert input_ids.dims() == 2, "Input prompt should be of shape (batch_size, sequence length)." # current position and vocab size - cur_len = 1 + cur_len = input_ids.shape[1] vocab_size = self.config.vocab_size + if num_beams > 1: + return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty, + num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size) + + return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, + temperature, top_k, top_p, repetition_penalty, + pad_token_id, eos_token_ids, batch_size) + + def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, + temperature, top_k, top_p, repetition_penalty, + pad_token_id, eos_token_ids, batch_size): + """ Generate a sentence without beam search (num_beams == 1). """ + # current position / max lengths / length of generated sentences / unfinished sentences + unfinished_sents = input_ids.new(batch_size).fill_(1) + + # cache compute states + pasts = None + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) + outputs = self(**model_inputs) + next_token_logits = outputs[0][:, -1, :] + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + for i in range(batch_size): + for _ in set(input_ids[i].tolist()): + next_token_logits[i, _] /= repetition_penalty + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + # Top-p/top-k filtering + next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + # Sample + next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) + + # update generations and finished sentences + tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents) + input_ids = torch.cat([input_ids, tokens_to_add], dim=-1) + for eos_token_id in eos_token_ids: + unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long()) + cur_len = cur_len + 1 + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sents.max() == 0: + break + + # add eos_token_ids to unfinished sentences + if cur_len == max_length: + input_ids[:, -1].masked_fill_(unfinished_sents.byte(), eos_token_ids[0]) + + return input_ids + + def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, length_penalty, + num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size): + """ Generate a sentence with beam search. """ # Expand input to num beams input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) @@ -592,9 +653,11 @@ class PreTrainedModel(nn.Module): scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) assert scores.size() == (batch_size * num_beams, vocab_size) - # select next words with scores - _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) - _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size) + # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product) + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size) next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams) @@ -604,11 +667,11 @@ class PreTrainedModel(nn.Module): next_batch_beam = [] # for each sentence - for sent_id in range(batch_size): + for batch_ex in range(batch_size): # if we are done with this sentence - done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item()) - if done[sent_id]: + done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item()) + if done[batch_ex]: next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch continue @@ -616,7 +679,7 @@ class PreTrainedModel(nn.Module): next_sent_beam = [] # next words for this sentence - for idx, value in zip(next_words[sent_id], next_scores[sent_id]): + for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]): # get beam and word IDs beam_id = idx // vocab_size @@ -624,9 +687,9 @@ class PreTrainedModel(nn.Module): # end of sentence, or next word if word_id.item() in eos_token_ids or cur_len + 1 == max_length: - generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item()) + generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item()) else: - next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id)) + next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id)) # the beam for next step is full if len(next_sent_beam) == num_beams: @@ -637,7 +700,7 @@ class PreTrainedModel(nn.Module): if len(next_sent_beam) == 0: next_sent_beam = [(0, pad_token_id, 0)] * num_beams # pad the batch next_batch_beam.extend(next_sent_beam) - assert len(next_batch_beam) == num_beams * (sent_id + 1) + assert len(next_batch_beam) == num_beams * (batch_ex + 1) # sanity check / prepare next batch assert len(next_batch_beam) == batch_size * num_beams @@ -670,7 +733,7 @@ class PreTrainedModel(nn.Module): # print("") # select the best hypotheses - tgt_len = src_len.new(batch_size) + tgt_len = input_ids.new(batch_size) best = [] for i, hypotheses in enumerate(generated_hyps): @@ -679,15 +742,46 @@ class PreTrainedModel(nn.Module): best.append(best_hyp) # generate target batch - decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index) + decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) for i, hypo in enumerate(best): - decoded[:tgt_len[i] - 1, i] = hypo - decoded[tgt_len[i] - 1, i] = self.eos_index + decoded[i, :tgt_len[i] - 1] = hypo + decoded[i, tgt_len[i] - 1] = eos_token_ids[0] - # sanity check - assert (decoded == self.eos_index).sum() == 2 * batch_size + # # sanity check + # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size - return decoded, tgt_len + return decoded + + +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size x vocabulary size) + top_k > 0: keep only top k tokens with highest probability (top-k filtering). + top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + top_k = min(top_k, logits.size(-1)) # Safety check + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits class BeamHypotheses(object): From 77d397202ba3daa013c94696e9825de8e20145e8 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 17 Dec 2019 23:28:46 +0100 Subject: [PATCH 05/11] clean up dead code --- transformers/modeling_utils.py | 140 --------------------------------- 1 file changed, 140 deletions(-) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 52743d8c2f..0e285c4f6b 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -830,146 +830,6 @@ class BeamHypotheses(object): return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty -class Sampler(object): - r""" Sampler is used to generate sequences of ids from logit inputs. - - Greedy decoding, which consists in chosing the most probable token at each - step, is the default behaviour. Sampling with varying temperature, top_k - and nucleus filtering is also implemented. - - Attributes: - **device**: ``torch.device`` - Device on which the computations will be run. - **do_sample**: bool - Whether to sample or do greedy decoding. - **k**: int between 0 and vocab_size - Parameter for the top-k filtering - **p**: float between 0 and 1 - Parameter for the nucleus filtering - **temperature**: strictly positive float - Parameter used to modulate the distribution over ids. Low temperatures - put more emphasis on highly probably token while high temperatures tend - to smooth the probability distribution. - **repetition_penalty**: strictly postitive float - The penalty applied to repeating ids - """ - - def __init__( - self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0 - ): - self.k = k - self.p = p - self.do_sample = do_sample - self.temperature = temperature - self.repetition_penalty = repetition_penalty - - self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False - - if self.p > 1: - warnings.warn( - """You are trying to apply nucleus filtering with a value of p greater than 1 ({}). - However p is a probability and its value must lie between 0 and 1. In effect, no filtering - will be applied. If this is not the behavior you expect, change the value of p.""".format( - self.p - ) - ) - - def get_one_token(self, next_token_logits, past_sequence): - logits = self.apply_repetition_penalty(next_token_logits, past_sequence) - if self.do_sample: - logits = self.apply_temperature(logits) - logits = self.apply_top_k_filter(logits) - logits = self.apply_nucleus_filter(logits) - return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) - return torch.argmax(logits, dim=-1).unsqueeze(-1) - - def apply_repetition_penalty(self, logits, past_sequence): - """ Apply a penalty to tokens that appear more than once in the - generated sequence. - - .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer - language model for controllable generation." arXiv preprint - arXiv:1909.05858 (2019). - """ - if self.do_apply_repetition_penalty: - generated_token_idx = set(past_sequence[0].tolist()) - for token_idx in generated_token_idx: - logits[0, token_idx] /= self.repetition_penalty - return logits - - def apply_temperature(self, logits): - """ Shape the tokens' distribution through temperature. The higher the value - of the temperature, the more skewed towards high probability events the - distribution is. - - .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. - MIT press, 2016. - """ - # when dividing a float by 0, torch returns inf which in turns breaks the - # multinomial with an error message that is not very helpful. It is better - # for the user to break the execution and explain why. - if self.temperature == 0: - raise ZeroDivisionError( - """You are trying to sample with a temperature equal to 0. - If you wanted to do greedy sampling, set instead `do_sample` to False. - Otherwise set the temperature to a value different from 0.""" - ) - return logits / self.temperature - - def apply_top_k_filter(self, logits): - """ Use the probability distribution of the tokens to determine the set - to be sampled from. Specifically we select the set of size k such that - the sum of its items' probabilities is maximum. - - .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural - story generation." arXiv preprint arXiv:1805.04833 (2018). - """ - if self.k > 0: - vocabulary_size = logits.size(-1) - if self.k > vocabulary_size: - warnings.warn( - """You provided a value for k ({}) that is larger than the vocabulary size ({}). - We adjusted k's value to the vocabulary size; if that was what you intended to do - we recommend setting k to 0 instead. It this is not the behavior you expected, - choose a value of k that is smaller than the vocabulary size.""".format( - self.k, vocabulary_size - ) - ) - self.k = vocabulary_size - - indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None] - logits[indices_to_remove] = -float("Inf") - - return logits - - def apply_nucleus_filter(self, logits): - """ Use the probability distribution of the tokens to determine the set - to be sampled from. Specifically, choose the smallest set such that the - sum of its items' probabilities is greater than a number p in [0,1]. - - .. Holtzman, Ari, et al. "The curious case of neural text - degeneration." arXiv preprint arXiv:1904.09751 (2019). - """ - if self.p > 0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - sorted_probabilities = F.softmax(sorted_logits, dim=-1) - cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1) - - # Remove tokens with cumulative probability above the threshold, - # but keep the first token above the threshold. - sorted_indices_to_remove = cumulative_probabilities > self.p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter( - dim=-1, index=sorted_indices, src=sorted_indices_to_remove - ) - logits[indices_to_remove] = -float("Inf") - - return logits - - class Conv1D(nn.Module): def __init__(self, nf, nx): """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2) From 641a8decdc6c34ce1837c9602fe84a65ec5b741a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 18 Dec 2019 10:43:48 +0100 Subject: [PATCH 06/11] clean up code and add arbitrary number of return sequences --- transformers/configuration_utils.py | 25 +-- transformers/modeling_encoder_decoder.py | 95 ---------- transformers/modeling_utils.py | 163 +++++++++++------ transformers/tests/sampling_test.py | 213 ----------------------- 4 files changed, 119 insertions(+), 377 deletions(-) delete mode 100644 transformers/tests/sampling_test.py diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 8c3e0a9f9c..456af3341c 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -62,18 +62,19 @@ class PretrainedConfig(object): self.is_decoder = kwargs.pop('is_decoder', False) # Parameters for sequence generation - self.generate_max_length = kwargs.pop('generate_max_length', 20) - self.generate_do_sample = kwargs.pop('generate_do_sample', False) - self.generate_num_beams = kwargs.pop('generate_num_beams', 1) - self.generate_temperature = kwargs.pop('generate_temperature', 1.0) - self.generate_top_k = kwargs.pop('generate_top_k', 50) - self.generate_top_p = kwargs.pop('generate_top_p', 1.0) - self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0) - self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0) - self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0) - self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0) - self.generate_batch_size = kwargs.pop('generate_batch_size', 1) - self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.) + self.max_length = kwargs.pop('max_length', 20) + self.do_sample = kwargs.pop('do_sample', False) + self.num_beams = kwargs.pop('num_beams', 1) + self.temperature = kwargs.pop('temperature', 1.0) + self.top_k = kwargs.pop('top_k', 50) + self.top_p = kwargs.pop('top_p', 1.0) + self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0) + self.bos_token_id = kwargs.pop('bos_token_id', 0) + self.pad_token_id = kwargs.pop('pad_token_id', 0) + self.eos_token_ids = kwargs.pop('eos_token_ids', 0) + self.batch_size = kwargs.pop('batch_size', 1) + self.length_penalty = kwargs.pop('length_penalty', 1.) + self.num_return_sequences = kwargs.pop('num_return_sequences', 1) def save_pretrained(self, save_directory): """ Save a configuration object to the directory `save_directory`, so that it diff --git a/transformers/modeling_encoder_decoder.py b/transformers/modeling_encoder_decoder.py index 3d8c812c2f..d69a75cc75 100644 --- a/transformers/modeling_encoder_decoder.py +++ b/transformers/modeling_encoder_decoder.py @@ -25,7 +25,6 @@ from torch import nn from tqdm import trange from .modeling_auto import AutoModel, AutoModelWithLMHead -from .modeling_utils import Sampler logger = logging.getLogger(__name__) @@ -203,100 +202,6 @@ class PreTrainedEncoderDecoder(nn.Module): return decoder_outputs + encoder_outputs - def decode( - self, - encoder_input_ids, - decoder_prompt_ids=None, - device=torch.device("cpu"), - length=10, - do_sample=False, - temperature=1.0, - k=9, - p=0., - repetition_penalty=1., - **kwargs - ): - """ Generic sequence generator for encoder-decoder models. - - For encoder-decoders the generation consists in: - - Performing a forward pass through the encoder once; - - Pass the encoder's hidden states to a decoding mechanism that - repeatedly calls the decoder to generate sequences. - - The method currently supports greedy decoding and sampling. See the - documentation of the `Sampler` class for more information about the - parameters related to sampling. - - Params: - **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length) - The sequence to encode. - **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) - The sequence used as a prompt for the generation. If `None` the method initializes - it as an empty `torch.LongTensor` of shape (1,) - **device**: (`optional`) `torch.device` - The device on which the prompt_ids will be initialized if not provided. - **length**: (`optional`) int - The length of the sequence to be generated. - **do_sample**: (`optional`) bool - If set to `False` we use greedy decoding; otherwise sampling. - **temperature**: (`optional`) float - The value used to module the next token probabilities. - **k**: (`optional`) int - The parameter used for k-filtering. - **p**: (`optional`) float - The parameter for nucleus sampling. Must be between 0 and 1. - **repetition_penalty**: (`optional`) float - The parameter for repetition penalty. - """ - if decoder_prompt_ids is None: - decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device) - - # When the model does not have a LM head `get_output_embeddings` - # returns `None`. We use this mechanism to determine whether we - # should proceed with decoding or not. - if self.decoder.get_output_embeddings() is None: - raise AttributeError("You tried do generated sequences with a decoder that does not have a LM Head.") - - # The followings checks that the decoder is on the same device as the one - # that is specified. It only works for models that fit on one GPU. - decoder_device = next(self.decoder.parameters()).device - if decoder_device != decoder_prompt_ids.device: - warnings.warn( - "The decoder is not on the same device as the prompt. Expected {}, got {}.".format( - decoder_prompt_ids.device, decoder_device - ) - ) - - kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs) - with torch.no_grad(): - encoder_outputs = self.encoder(encoder_input_ids, **kwargs) - encoder_hidden_states = encoder_outputs[0] - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - - sampler_config = { - "k": k, - "p": p, - "do_sample": do_sample, - "temperature": temperature, - "repetition_penalty": repetition_penalty, - } - return self._greedy_decode_or_sample( - decoder_prompt_ids, length, sampler_config, **kwargs_decoder - ) - - def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder): - sampler = Sampler(**sampler_config) - with torch.no_grad(): - generated_sequence = prompt_ids - for _ in trange(length): - arguments = self.decoder._prepare_inputs_for_decoding(generated_sequence, **kwargs_decoder) - outputs = self.decoder(**arguments) - next_tokens_logits = outputs[0][:, -1, :] - next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence) - generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) - - return generated_sequence.squeeze(0) - @staticmethod def prepare_model_kwargs(**kwargs): """ Prepare the encoder and decoder's keyword arguments. diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 0e285c4f6b..6fa68a0db4 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -494,7 +494,7 @@ class PreTrainedModel(nn.Module): def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None, - length_penalty=None, **kwargs): + length_penalty=None, num_return_sequences=None, **kwargs): """ Sequence generator for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling @@ -526,18 +526,19 @@ class PreTrainedModel(nn.Module): if self.get_output_embeddings() is None: raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") - max_length = max_length if max_length is not None else self.config.generate_max_length - do_sample = do_sample if do_sample is not None else self.config.generate_do_sample - num_beams = num_beams if num_beams is not None else self.config.generate_num_beams - temperature = temperature if temperature is not None else self.config.generate_temperature - top_k = top_k if top_k is not None else self.config.generate_top_k - top_p = top_p if top_p is not None else self.config.generate_top_p - repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty - bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id - eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids - batch_size = batch_size if batch_size is not None else self.config.generate_batch_size - length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty + max_length = max_length if max_length is not None else self.config.max_length + do_sample = do_sample if do_sample is not None else self.config.do_sample + num_beams = num_beams if num_beams is not None else self.config.num_beams + temperature = temperature if temperature is not None else self.config.temperature + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids + batch_size = batch_size if batch_size is not None else self.config.batch_size + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences if input_ids is not None: batch_size = input_ids.shape[0] # overriden by the input batch_size @@ -547,8 +548,8 @@ class PreTrainedModel(nn.Module): assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer." - assert temperature > 0, "`temperature` should be positive." - assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictely positive integer." + assert temperature > 0, "`temperature` should be strictely positive." + assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer." @@ -557,30 +558,41 @@ class PreTrainedModel(nn.Module): "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictely positive integer." assert length_penalty > 0, "`length_penalty` should be strictely positive." + assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictely positive integer." if input_ids is None: input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device) else: - assert input_ids.dims() == 2, "Input prompt should be of shape (batch_size, sequence length)." + assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." # current position and vocab size cur_len = input_ids.shape[1] vocab_size = self.config.vocab_size if num_beams > 1: - return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty, - num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size) - + return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, + temperature, top_k, top_p, repetition_penalty, + pad_token_id, eos_token_ids, batch_size, + num_return_sequences, + length_penalty, num_beams, vocab_size) return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size) + pad_token_id, eos_token_ids, batch_size, + num_return_sequences) def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size): - """ Generate a sentence without beam search (num_beams == 1). """ + pad_token_id, eos_token_ids, batch_size, + num_return_sequences): + """ Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + # Expand input to num return sequences + input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len) + input_ids = input_ids.contiguous().view(batch_size*num_return_sequences, cur_len) # (batch_size*num_return_sequences, cur_len) + # current position / max lengths / length of generated sentences / unfinished sentences - unfinished_sents = input_ids.new(batch_size).fill_(1) + unfinished_sents = input_ids.new(batch_size*num_return_sequences).fill_(1) # cache compute states pasts = None @@ -592,9 +604,9 @@ class PreTrainedModel(nn.Module): # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: - for i in range(batch_size): - for _ in set(input_ids[i].tolist()): - next_token_logits[i, _] /= repetition_penalty + for i in range(batch_size*num_return_sequences): + for previous_tokens in set(input_ids[i].tolist()): + next_token_logits[i, previous_tokens] /= repetition_penalty if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) @@ -603,16 +615,16 @@ class PreTrainedModel(nn.Module): # Top-p/top-k filtering next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) # Sample - next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1) + next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1) else: # Greedy decoding - next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) + next_token = torch.argmax(next_token_logits, dim=-1) # update generations and finished sentences tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents) - input_ids = torch.cat([input_ids, tokens_to_add], dim=-1) + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) for eos_token_id in eos_token_ids: - unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long()) + unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long()) cur_len = cur_len + 1 # stop when there is a in each sentence, or if we exceed the maximul length @@ -621,13 +633,24 @@ class PreTrainedModel(nn.Module): # add eos_token_ids to unfinished sentences if cur_len == max_length: - input_ids[:, -1].masked_fill_(unfinished_sents.byte(), eos_token_ids[0]) + input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0]) + + if num_return_sequences != 1: + input_ids = input_ids.view(batch_size, num_return_sequences, -1) return input_ids - def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, length_penalty, - num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size): - """ Generate a sentence with beam search. """ + def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, + temperature, top_k, top_p, repetition_penalty, + pad_token_id, eos_token_ids, batch_size, + num_return_sequences, + length_penalty, num_beams, vocab_size): + """ Generate `num_return_sequences` sequences per batch example with beam search. + We return the top-`num_return_sequences` beams. + `num_return_sequences` should be bigger than `num_beams` (we default to the min of both) + """ + num_return_sequences = min(num_return_sequences, num_beams) + # Expand input to num beams input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) @@ -638,7 +661,7 @@ class PreTrainedModel(nn.Module): # scores for each sentence in the beam beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view(-1) + beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states pasts = None # self.prepare_pasts() @@ -648,18 +671,40 @@ class PreTrainedModel(nn.Module): while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) - scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) - scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) - scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) - assert scores.size() == (batch_size * num_beams, vocab_size) + scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) + scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) - # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product) - _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + for i in range(batch_size * num_beams): + for previous_tokens in set(input_ids[i].tolist()): + scores[i, previous_tokens] /= repetition_penalty - # re-organize to group the beam together (we are keeping top hypothesis accross beams) - _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size) + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + scores = scores / temperature + # Top-p/top-k filtering + scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) # (batch_size * num_beams, vocab_size) + # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search) + next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2) + # Compute next scores + _scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) + _scores = torch.gather(_scores, -1, next_words) # (batch_size * num_beams, 2) + next_scores = _scores + beam_scores[:, None].expand_as(_scores) # (batch_size * num_beams, 2) + # Match shape of greedy beam search + next_words = next_words.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams) + next_scores = next_scores.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams) + else: + # do greedy beam search + scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) + assert scores.size() == (batch_size * num_beams, vocab_size) + # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product) + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size) + next_scores, next_words = torch.topk(_scores, 2*num_beams, dim=1, largest=True, sorted=True) - next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams) # next batch beam content @@ -733,32 +778,36 @@ class PreTrainedModel(nn.Module): # print("") # select the best hypotheses - tgt_len = input_ids.new(batch_size) - best = [] + tgt_len = input_ids.new(batch_size, num_return_sequences) + bests = [] for i, hypotheses in enumerate(generated_hyps): - best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] - tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol - best.append(best_hyp) + best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]] + for j, hyp in enumerate(best_hyps): + tgt_len[i, j] = len(hyp) + 1 # +1 for the symbol + bests.append(best_hyps) # generate target batch - decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) - for i, hypo in enumerate(best): - decoded[i, :tgt_len[i] - 1] = hypo - decoded[i, tgt_len[i] - 1] = eos_token_ids[0] + decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id) + for i, hyps in enumerate(bests): + for j, hypo in enumerate(hyps): + decoded[i, j, :tgt_len[i, j] - 1] = hypo + decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0] + if num_return_sequences == 1: + decoded = decoded.squeeze(1) # # sanity check # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size return decoded -def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): +def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')): """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: logits: logits distribution shape (batch size x vocabulary size) - top_k > 0: keep only top k tokens with highest probability (top-k filtering). - top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ @@ -768,7 +817,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf') indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits[indices_to_remove] = filter_value - if top_p > 0.0: + if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) diff --git a/transformers/tests/sampling_test.py b/transformers/tests/sampling_test.py deleted file mode 100644 index 98cc23bf2b..0000000000 --- a/transformers/tests/sampling_test.py +++ /dev/null @@ -1,213 +0,0 @@ -# coding=utf-8 -import sys -import unittest - -import numpy as np -import pytest - -from transformers import is_torch_available - -if is_torch_available(): - import torch - - from transformers import ( - BertConfig, - BertModel, - GPT2Config, - GPT2LMHeadModel, - OpenAIGPTConfig, - OpenAIGPTLMHeadModel, - TransfoXLConfig, - TransfoXLLMHeadModel, - XLMConfig, - XLMWithLMHeadModel, - XLNetConfig, - XLNetLMHeadModel, - Model2Model, - ) - from transformers.modeling_utils import Sampler -else: - pytestmark = pytest.mark.skip("Require Torch") - - -class SamplerTest(unittest.TestCase): - def test_nucleus_sampling(self): - inf = -float("Inf") - test_cases = ( - { - "p": 0, - "logits": torch.tensor([0.3, 0.1, 0.2]), - "expected": torch.tensor([0.3, 0.1, 0.2]), - }, - { - "p": 0.01, - "logits": torch.tensor([0.3, 0.1, 0.2]), - "expected": torch.tensor([0.3, inf, inf]), - }, - { - "p": 1, - "logits": torch.tensor([0.3, 0.1, 0.2]), - "expected": torch.tensor([0.3, 0.1, 0.2]), - }, - { - "p": 0.2, - "logits": torch.tensor([0.7, 0.1, 0.2]), - "expected": torch.tensor([0.7, inf, inf]), - }, - { - "p": 0.71, - "logits": torch.tensor([0.7, 0.1, 0.2]), - "expected": torch.tensor([0.7, inf, 0.2]), - }, - { - "p": 0.71, - "logits": torch.tensor([0.1, 0.7, 0.2]), - "expected": torch.tensor([inf, 0.7, 0.2]), - }, - { - "p": 0.71, - "logits": torch.tensor([0.7, 0.2, 0.1]), - "expected": torch.tensor([0.7, 0.2, inf]), - }, - { - "p": 0.91, - "logits": torch.tensor([0.7, 0.1, 0.2]), - "expected": torch.tensor([0.7, 0.1, 0.2]), - }, - ) - for case in test_cases: - config = { - "do_sample": True, - "temperature": 1.0, - "k": 0, - "p": case["p"], - "repetition_penalty": 1.0, - } - sampler = Sampler(**config) - filtered_logits = sampler.apply_nucleus_filter(case["logits"]) - np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy()) - - def test_top_k_filter(self): - inf = -float("Inf") - test_cases = ( - { - "k": 0, - "logits": torch.tensor([0.7, 0.1, 0.2]), - "expected": torch.tensor([0.7, 0.1, 0.2]), - }, - { - "k": 1, - "logits": torch.tensor([0.7, 0.1, 0.2]), - "expected": torch.tensor([0.7, inf, inf]), - }, - { - "k": 2, - "logits": torch.tensor([0.7, 0.1, 0.2]), - "expected": torch.tensor([0.7, inf, 0.2]), - }, - { - "k": 3, - "logits": torch.tensor([0.7, 0.1, 0.2]), - "expected": torch.tensor([0.7, 0.1, 0.2]), - }, - ) - for case in test_cases: - config = { - "do_sample": True, - "temperature": 1.0, - "k": case["k"], - "p": 0, - "repetition_penalty": 1.0, - } - sampler = Sampler(**config) - filtered_logits = sampler.apply_top_k_filter(case["logits"]) - np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy()) - - @pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2") - def test_wrong_k_value(self): - case = {"k": 10, "vocab_size": 5} - config = { - "do_sample": True, - "temperature": 1.0, - "k": case["k"], - "p": 0, - "repetition_penalty": 1.0, - } - sampler = Sampler(**config) - next_token_logits = torch.rand(case["vocab_size"]).unsqueeze(0) - past_sequence = torch.tensor([]) - with self.assertWarns(UserWarning): - _ = sampler.get_one_token(next_token_logits, past_sequence) - - def test_zero_temperature(self): - temperature = 0 - config = { - "do_sample": True, - "temperature": temperature, - "k": 0, - "p": 0, - "repetition_penalty": 1.0, - } - sampler = Sampler(**config) - next_token_logits = torch.rand(10).unsqueeze(0) - past_sequence = torch.tensor([]) - with self.assertRaises(ZeroDivisionError): - _ = sampler.get_one_token(next_token_logits, past_sequence) - - -class SamplerSingleStackTest(unittest.TestCase): - def test_raises_exception_when_no_LM_head(self): - models = [BertModel(BertConfig())] - for model in models: - with self.assertRaises(AttributeError): - model.decode() - - @pytest.mark.slow - def test_forward_pass_and_output_length(self): - models = { - "XLNet": XLNetLMHeadModel(XLNetConfig()), - "XLM": XLMWithLMHeadModel(XLMConfig()), - "TransfoXL": TransfoXLLMHeadModel(TransfoXLConfig()), - "GPT2": GPT2LMHeadModel(GPT2Config()), - "GPT": OpenAIGPTLMHeadModel(OpenAIGPTConfig()), - } - kwargs = { - "XLNet": {}, - "XLM": {"mask_token": 0}, - "TransfoXL": {}, - "GPT2": {}, - "GPT": {}, - } - prompt = torch.tensor([[1, 2, 3]], dtype=torch.long) - generated_length = 5 - expected_length = 8 - - for name, model in models.items(): - kwargs_model = kwargs[name] - output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs_model) - self.assertEqual(len(output), expected_length) - - -class SamplerEncoderDecoderTest(unittest.TestCase): - @pytest.mark.slow - def test_forward_pass_and_output_length(self): - model = Model2Model.from_pretrained("bert-base-uncased") - - encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) - prompt = torch.tensor([[1, 2, 3]], dtype=torch.long) - generated_length = 5 - expected_length = 8 - - output = model.decode( - encoder_input_ids, - decoder_prompt_ids=prompt, - k=2, - p=0.5, - repetition_penalty=2, - length=generated_length, - ) - self.assertEqual(len(output), expected_length) - - -if __name__ == "__main__": - unittest.main() From 8e5587fb7935e3040c11118041ff729d33adcb09 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 18 Dec 2019 11:32:37 +0100 Subject: [PATCH 07/11] few fixes on sampling --- transformers/modeling_utils.py | 95 +++++++++++++++------------------- 1 file changed, 42 insertions(+), 53 deletions(-) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 6fa68a0db4..bbfb0614ad 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -23,14 +23,12 @@ import json import logging import os from io import open -import warnings import six import torch from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F -from tqdm import trange from .configuration_utils import PretrainedConfig from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME @@ -82,7 +80,6 @@ class PreTrainedModel(nn.Module): "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ )) - # Save config in model self.config = config @@ -220,9 +217,6 @@ class PreTrainedModel(nn.Module): # Tie weights if needed self.tie_weights() - # Initialize decoding head if we have output embeddings - - def prune_heads(self, heads_to_prune): """ Prunes heads of the base model. @@ -569,30 +563,36 @@ class PreTrainedModel(nn.Module): cur_len = input_ids.shape[1] vocab_size = self.config.vocab_size + if num_return_sequences != 1: + # Expand input to num return sequences + input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len) + input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len) # (batch_size * num_return_sequences, cur_len) + effective_batch_size = batch_size * num_return_sequences + else: + effective_batch_size = batch_size + if num_beams > 1: - return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, - temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size, - num_return_sequences, - length_penalty, num_beams, vocab_size) - return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, + output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample, + temperature, top_k, top_p, repetition_penalty, + pad_token_id, eos_token_ids, effective_batch_size, + length_penalty, num_beams, vocab_size) + else: + output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size, - num_return_sequences) + pad_token_id, eos_token_ids, effective_batch_size) + + if num_return_sequences != 1: + output = output.view(batch_size, num_return_sequences, -1) + return output def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size, - num_return_sequences): - """ Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1). + pad_token_id, eos_token_ids, batch_size): + """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. """ - # Expand input to num return sequences - input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len) - input_ids = input_ids.contiguous().view(batch_size*num_return_sequences, cur_len) # (batch_size*num_return_sequences, cur_len) - # current position / max lengths / length of generated sentences / unfinished sentences - unfinished_sents = input_ids.new(batch_size*num_return_sequences).fill_(1) + unfinished_sents = input_ids.new(batch_size).fill_(1) # cache compute states pasts = None @@ -604,7 +604,7 @@ class PreTrainedModel(nn.Module): # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: - for i in range(batch_size*num_return_sequences): + for i in range(batch_size): for previous_tokens in set(input_ids[i].tolist()): next_token_logits[i, previous_tokens] /= repetition_penalty @@ -635,22 +635,14 @@ class PreTrainedModel(nn.Module): if cur_len == max_length: input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0]) - if num_return_sequences != 1: - input_ids = input_ids.view(batch_size, num_return_sequences, -1) - return input_ids def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, batch_size, - num_return_sequences, length_penalty, num_beams, vocab_size): - """ Generate `num_return_sequences` sequences per batch example with beam search. - We return the top-`num_return_sequences` beams. - `num_return_sequences` should be bigger than `num_beams` (we default to the min of both) + """ Generate sequences for each example with beam search. """ - num_return_sequences = min(num_return_sequences, num_beams) - # Expand input to num beams input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) @@ -685,7 +677,7 @@ class PreTrainedModel(nn.Module): if temperature != 1.0: scores = scores / temperature # Top-p/top-k filtering - scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) # (batch_size * num_beams, vocab_size) + scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size) # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search) next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2) # Compute next scores @@ -778,41 +770,35 @@ class PreTrainedModel(nn.Module): # print("") # select the best hypotheses - tgt_len = input_ids.new(batch_size, num_return_sequences) - bests = [] + tgt_len = input_ids.new(batch_size) + best = [] for i, hypotheses in enumerate(generated_hyps): - best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]] - for j, hyp in enumerate(best_hyps): - tgt_len[i, j] = len(hyp) + 1 # +1 for the symbol - bests.append(best_hyps) + best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] + tgt_len[i] = len(best_hyp) + 1 # +1 for the symbol + best.append(best_hyp) # generate target batch - decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id) - for i, hyps in enumerate(bests): - for j, hypo in enumerate(hyps): - decoded[i, j, :tgt_len[i, j] - 1] = hypo - decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0] - - if num_return_sequences == 1: - decoded = decoded.squeeze(1) - # # sanity check - # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size + decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) + for i, hypo in enumerate(best): + decoded[i, :tgt_len[i] - 1] = hypo + decoded[i, tgt_len[i] - 1] = eos_token_ids[0] return decoded -def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')): +def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1): """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: - logits: logits distribution shape (batch size x vocabulary size) + logits: logits distribution shape (batch size, vocabulary size) if top_k > 0: keep only top k tokens with highest probability (top-k filtering). if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ - top_k = min(top_k, logits.size(-1)) # Safety check if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits[indices_to_remove] = filter_value @@ -821,8 +807,11 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf') sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - # Remove tokens with cumulative probability above the threshold + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 From 3d2096f516e99da79f1c6c60a48f828b4e7733ef Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 18 Dec 2019 11:50:54 +0100 Subject: [PATCH 08/11] further cleanup --- examples/run_generation.py | 13 ++++--- transformers/configuration_xlm.py | 4 +++ transformers/modeling_utils.py | 18 ++++++---- transformers/modeling_xlm.py | 39 +++++++------------- transformers/modeling_xlnet.py | 58 +++++++++++++----------------- transformers/tokenization_utils.py | 2 +- 6 files changed, 58 insertions(+), 76 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 2075ad8457..8121f4f5aa 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -91,7 +91,7 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text): def prepare_xlm_input(args, model, tokenizer, prompt_text): - kwargs = {"language": None, "mask_token": None} + kwargs = {"language": None, "mask_token_id": None} # Set the language use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb @@ -112,7 +112,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text): # XLM masked-language modeling (MLM) models need masked token is_xlm_mlm = "mlm" in args.model_name_or_path if is_xlm_mlm: - kwargs["mask_token"] = tokenizer.mask_token_id + kwargs["mask_token_id"] = tokenizer.mask_token_id return prompt_text, kwargs @@ -204,14 +204,13 @@ def main(): prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text) encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0) - output_sequences = model.decode( - prompt_ids=encoded_prompt, + output_sequences = model.generate( + intput_ids=encoded_prompt, length=args.length, temperature=args.temperature, - k=args.k, - p=args.p, + top_k=args.k, + top_p=args.p, repetition_penalty=args.repetition_penalty, - device=args.device, **model_kwargs, ) diff --git a/transformers/configuration_xlm.py b/transformers/configuration_xlm.py index fa3a5f40f6..1938b85741 100644 --- a/transformers/configuration_xlm.py +++ b/transformers/configuration_xlm.py @@ -113,6 +113,8 @@ class XLMConfig(PretrainedConfig): summary_first_dropout=0.1, start_n_top=5, end_n_top=5, + mask_token_id = 0, + lang_id = 0, **kwargs): """Constructs XLMConfig. """ @@ -156,6 +158,8 @@ class XLMConfig(PretrainedConfig): self.summary_first_dropout = summary_first_dropout self.start_n_top = start_n_top self.end_n_top = end_n_top + self.mask_token_id = mask_token_id + self.lang_id = lang_id else: raise ValueError("First argument must be either a vocabulary size (int)" " or the path to a pretrained model config file (str)") diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index bbfb0614ad..f55c209ac0 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -488,7 +488,7 @@ class PreTrainedModel(nn.Module): def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None, - length_penalty=None, num_return_sequences=None, **kwargs): + length_penalty=None, num_return_sequences=None, **model_kwargs): """ Sequence generator for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling @@ -575,11 +575,13 @@ class PreTrainedModel(nn.Module): output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size, - length_penalty, num_beams, vocab_size) + length_penalty, num_beams, vocab_size, + **model_kwargs) else: output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, effective_batch_size) + pad_token_id, eos_token_ids, effective_batch_size, + **model_kwargs) if num_return_sequences != 1: output = output.view(batch_size, num_return_sequences, -1) @@ -587,7 +589,8 @@ class PreTrainedModel(nn.Module): def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size): + pad_token_id, eos_token_ids, batch_size, + **model_kwargs): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. """ @@ -598,7 +601,7 @@ class PreTrainedModel(nn.Module): pasts = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] @@ -640,7 +643,8 @@ class PreTrainedModel(nn.Module): def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, batch_size, - length_penalty, num_beams, vocab_size): + length_penalty, num_beams, vocab_size, + **model_kwargs): """ Generate sequences for each example with beam search. """ # Expand input to num beams @@ -662,7 +666,7 @@ class PreTrainedModel(nn.Module): done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs) scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py index 295fff7943..6691b0f60b 100644 --- a/transformers/modeling_xlm.py +++ b/transformers/modeling_xlm.py @@ -639,6 +639,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): def get_output_embeddings(self): return self.pred_layer.proj + def prepare_inputs_for_generation(self, input_ids, **model_kwargs): + mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id + lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id + + mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device) + input_ids = torch.cat([input_ids, mask_token], dim=1) + if lang_id is not None: + langs = torch.full_like(input_ids, lang_id) + else: + langs = None + return {"input_ids": input_ids, "langs": langs} + def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None, lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None): transformer_outputs = self.transformer(input_ids, @@ -657,33 +669,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): return outputs - def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs): - mask_token = model_kwargs.pop("mask_token", None) - language = model_kwargs.pop("language", None) - input_ids = self._append_mask_token(input_ids, mask_token) - langs = self._create_language_embeddings(input_ids, language) - arguments = {"input_ids": input_ids, "langs": langs} - arguments.update(model_kwargs) - - return arguments - - @staticmethod - def _append_mask_token(sequence, mask_token_id): - """ Append a [MASK] token at the end of the sequence that the MLM model - is going to try to predict. - """ - if mask_token_id is not None: - tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long) - return torch.cat((sequence, tokens_to_append), dim=1) - - return sequence - - @staticmethod - def _create_language_embeddings(sequence, language): - if language is not None: - return torch.tensor([language] * sequence.shape[1]).view(1, -1) - return None - @add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py index 2153923dd2..26b95076cd 100644 --- a/transformers/modeling_xlnet.py +++ b/transformers/modeling_xlnet.py @@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): def get_output_embeddings(self): return self.lm_loss + def prepare_inputs_for_generation(self, input_ids, **model_kwargs): + # Add dummy token at the end (no attention on this one) + dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + # Build permutation mask so that previous tokens don't see last token + perm_mask = torch.zeros( + (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]), + dtype=torch.float, device=input_ids.device + ) + perm_mask[:, :, -1] = 1.0 + + # We'll only predict the last token + target_mapping = torch.zeros( + (input_ids.shape[0], 1, input_ids.shape[1]), + dtype=torch.float, device=input_ids.device + ) + target_mapping[0, 0, -1] = 1.0 + + return {"input_ids": input_ids, + "perm_mask": perm_mask, + "target_mapping": target_mapping + } + def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None): transformer_outputs = self.transformer(input_ids, @@ -972,40 +996,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): return outputs # return (loss), logits, (mems), (hidden states), (attentions) - def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs): - input_ids = self._add_dummy_token(input_ids) - perm_mask = self._create_perm_mask(input_ids) - target_mapping = self._create_target_mapping(input_ids) - arguments = { - "input_ids": input_ids, - "perm_mask": perm_mask, - "target_mapping": target_mapping, - } - return arguments - - @staticmethod - def _add_dummy_token(sequence): - dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long) - return torch.cat((sequence, dummy), dim=1) - - @staticmethod - def _create_perm_mask(sequence): - mask = torch.zeros( - (sequence.shape[0], sequence.shape[1], sequence.shape[1]), - dtype=torch.float, - ) - mask[:, :, -1] = 1.0 # Previous tokens don't see last token - return mask - - @staticmethod - def _create_target_mapping(sequence): - target_mapping = torch.zeros( - (sequence.shape[0], 1, sequence.shape[1]), - dtype=torch.float, - ) - target_mapping[0, 0, -1] = 1.0 # predict last token - return target_mapping - @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index f4395cd82c..2e0d6caef2 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -761,7 +761,7 @@ class PreTrainedTokenizer(object): padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length. The tokenizer padding sides are handled by the following strings: - 'left': pads on the left of the sequences - - 'right': pads on the right of the sequences + - 'right': pads on the right of the sequences Defaults to False: no padding. return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. From 1c37746892a5fd680e88264346197bb313c8dd08 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 21 Dec 2019 13:52:49 +0100 Subject: [PATCH 09/11] fixing run_generation --- examples/run_generation.py | 9 ++++----- transformers/configuration_utils.py | 1 - transformers/modeling_utils.py | 9 +++++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 8121f4f5aa..67e1da7413 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -156,7 +156,7 @@ def main(): parser.add_argument("--length", type=int, default=20) parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") - parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling") + parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling") parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2") parser.add_argument("--k", type=int, default=0) parser.add_argument("--p", type=float, default=0.9) @@ -187,7 +187,6 @@ def main(): tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) model = model_class.from_pretrained(args.model_name_or_path) model.to(args.device) - model.eval() args.length = adjust_length_to_model( args.length, max_sequence_length=model.config.max_position_embeddings @@ -202,11 +201,11 @@ def main(): if requires_preprocessing: prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text) - encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0) + encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt') output_sequences = model.generate( - intput_ids=encoded_prompt, - length=args.length, + input_ids=encoded_prompt, + max_length=args.length, temperature=args.temperature, top_k=args.k, top_p=args.p, diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 456af3341c..ceb032a57c 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -72,7 +72,6 @@ class PretrainedConfig(object): self.bos_token_id = kwargs.pop('bos_token_id', 0) self.pad_token_id = kwargs.pop('pad_token_id', 0) self.eos_token_ids = kwargs.pop('eos_token_ids', 0) - self.batch_size = kwargs.pop('batch_size', 1) self.length_penalty = kwargs.pop('length_penalty', 1.) self.num_return_sequences = kwargs.pop('num_return_sequences', 1) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index f55c209ac0..5b28d5b755 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -485,9 +485,10 @@ class PreTrainedModel(nn.Module): def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} + @torch.no_grad() def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, - bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None, + bos_token_id=None, pad_token_id=None, eos_token_ids=None, length_penalty=None, num_return_sequences=None, **model_kwargs): """ Sequence generator for models with a LM head. @@ -530,19 +531,20 @@ class PreTrainedModel(nn.Module): bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids - batch_size = batch_size if batch_size is not None else self.config.batch_size length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences if input_ids is not None: batch_size = input_ids.shape[0] # overriden by the input batch_size + else: + batch_size = 1 if isinstance(eos_token_ids, int): eos_token_ids = [eos_token_ids] assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer." - assert temperature > 0, "`temperature` should be strictely positive." + # assert temperature > 0, "`temperature` should be strictely positive." assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." @@ -550,7 +552,6 @@ class PreTrainedModel(nn.Module): assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer." assert isinstance(eos_token_ids, (list, tuple)) and (e >= 0 for e in eos_token_ids), \ "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." - assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictely positive integer." assert length_penalty > 0, "`length_penalty` should be strictely positive." assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictely positive integer." From 300ec3003c282c5e3f06b33509af10dd0336d0ba Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 21 Dec 2019 14:02:19 +0100 Subject: [PATCH 10/11] fixing run_generation example - using torch.no_grad --- examples/run_generation.py | 31 ++++++++++++++----------------- transformers/configuration_xlm.py | 4 ++-- transformers/modeling_utils.py | 29 +++++++++++++---------------- transformers/modeling_xlm.py | 6 +++--- 4 files changed, 32 insertions(+), 38 deletions(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 67e1da7413..ade85f0269 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -87,11 +87,11 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text): logger.info( "WARNING! You are not starting your generation from a control code so you won't get good results" ) - return prompt_text, {} + return prompt_text def prepare_xlm_input(args, model, tokenizer, prompt_text): - kwargs = {"language": None, "mask_token_id": None} + # kwargs = {"language": None, "mask_token_id": None} # Set the language use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb @@ -107,14 +107,15 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text): + str(list(available_languages)) + " >>> " ) - kwargs["language"] = tokenizer.lang2id[language] + # kwargs["language"] = tokenizer.lang2id[language] + # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers # XLM masked-language modeling (MLM) models need masked token - is_xlm_mlm = "mlm" in args.model_name_or_path - if is_xlm_mlm: - kwargs["mask_token_id"] = tokenizer.mask_token_id + # is_xlm_mlm = "mlm" in args.model_name_or_path + # if is_xlm_mlm: + # kwargs["mask_token_id"] = tokenizer.mask_token_id - return prompt_text, kwargs + return prompt_text def prepare_xlnet_input(args, _, tokenizer, prompt_text): @@ -179,8 +180,8 @@ def main(): try: args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - except KeyError as ke: - raise ke( + except KeyError: + raise KeyError( "the model {} you specified is not supported. You are welcome to add it and open a PR :)" ) @@ -197,10 +198,9 @@ def main(): # Different models need different input formatting and/or extra arguments requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() - model_kwargs = {} if requires_preprocessing: prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) - prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text) + prompt_text = prepare_input(args, model, tokenizer, prompt_text) encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt') output_sequences = model.generate( @@ -210,14 +210,11 @@ def main(): top_k=args.k, top_p=args.p, repetition_penalty=args.repetition_penalty, - **model_kwargs, ) - generated_sequence = output_sequences.tolist()[ - encoded_prompt.size(1) : - ] # adapted to case where num_samples > 1 - text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) - text = text[: text.find(args.stop_token) if args.stop_token else None] + generated_sequence = output_sequences.tolist() + text = [tokenizer.decode(seq, clean_up_tokenization_spaces=True) for seq in generated_sequence] + # text = text[: text.find(args.stop_token) if args.stop_token else None] print(text) diff --git a/transformers/configuration_xlm.py b/transformers/configuration_xlm.py index 1938b85741..1134c7ab61 100644 --- a/transformers/configuration_xlm.py +++ b/transformers/configuration_xlm.py @@ -113,8 +113,8 @@ class XLMConfig(PretrainedConfig): summary_first_dropout=0.1, start_n_top=5, end_n_top=5, - mask_token_id = 0, - lang_id = 0, + mask_token_id=0, + lang_id=0, **kwargs): """Constructs XLMConfig. """ diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 5b28d5b755..005252c141 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -489,7 +489,7 @@ class PreTrainedModel(nn.Module): def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_ids=None, - length_penalty=None, num_return_sequences=None, **model_kwargs): + length_penalty=None, num_return_sequences=None): """ Sequence generator for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling @@ -519,7 +519,8 @@ class PreTrainedModel(nn.Module): # We cannot generate if the model does not have a LM head if self.get_output_embeddings() is None: - raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") + raise AttributeError("You tried to generate sequences with a model that does not have a LM Head." + "Please use another model class (e.g. `OpenAIGPTLMHeadModel`)") max_length = max_length if max_length is not None else self.config.max_length do_sample = do_sample if do_sample is not None else self.config.do_sample @@ -544,7 +545,7 @@ class PreTrainedModel(nn.Module): assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer." - # assert temperature > 0, "`temperature` should be strictely positive." + # assert temperature >= 0, "`temperature` should be positive." assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." @@ -576,13 +577,11 @@ class PreTrainedModel(nn.Module): output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size, - length_penalty, num_beams, vocab_size, - **model_kwargs) + length_penalty, num_beams, vocab_size) else: output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, effective_batch_size, - **model_kwargs) + pad_token_id, eos_token_ids, effective_batch_size) if num_return_sequences != 1: output = output.view(batch_size, num_return_sequences, -1) @@ -590,19 +589,18 @@ class PreTrainedModel(nn.Module): def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, - pad_token_id, eos_token_ids, batch_size, - **model_kwargs): + pad_token_id, eos_token_ids, batch_size): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. """ # current position / max lengths / length of generated sentences / unfinished sentences unfinished_sents = input_ids.new(batch_size).fill_(1) - # cache compute states + # TODO: add cached compute states pasts = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] @@ -614,7 +612,7 @@ class PreTrainedModel(nn.Module): if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: + if temperature > 0 and temperature != 1.0: next_token_logits = next_token_logits / temperature # Top-p/top-k filtering next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) @@ -644,8 +642,7 @@ class PreTrainedModel(nn.Module): def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, batch_size, - length_penalty, num_beams, vocab_size, - **model_kwargs): + length_penalty, num_beams, vocab_size): """ Generate sequences for each example with beam search. """ # Expand input to num beams @@ -667,7 +664,7 @@ class PreTrainedModel(nn.Module): done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts) scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size) scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size) @@ -679,7 +676,7 @@ class PreTrainedModel(nn.Module): if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: + if temperature > 0 and temperature != 1.0: scores = scores / temperature # Top-p/top-k filtering scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size) diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py index 6691b0f60b..35bada92af 100644 --- a/transformers/modeling_xlm.py +++ b/transformers/modeling_xlm.py @@ -639,9 +639,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): def get_output_embeddings(self): return self.pred_layer.proj - def prepare_inputs_for_generation(self, input_ids, **model_kwargs): - mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id - lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id + def prepare_inputs_for_generation(self, input_ids, **kwargs): + mask_token_id = self.config.mask_token_id + lang_id = self.config.lang_id mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device) input_ids = torch.cat([input_ids, mask_token], dim=1) From f86ed2318917edc9aa8e21b97f292fd623ad5273 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 21 Dec 2019 14:13:06 +0100 Subject: [PATCH 11/11] update doc --- transformers/modeling_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 25124f1fda..05e5ed3573 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -529,6 +529,16 @@ class PreTrainedModel(nn.Module): The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. **repetition_penalty**: (`optional`) float The parameter for repetition penalty. Between 1.0 and + infinity. 1.0 means no penalty. Default to 1. + **bos_token_id**: (`optional`) int + Beginning of sentence token if no prompt is provided. Default to 0. + **eos_token_ids**: (`optional`) int or list of int + End of sequence token or list of tokens to stop the generation. Default to 0. + **length_penalty**: (`optional`) int + Exponential penalty to the length. Default to 0. + **length_penalty**: (`optional`) float + Exponential penalty to the length. Default to 1. + **num_return_sequences**: (`optional`) int + The number of independantly computed returned sequences for each element in the batch. Default to 1. """ # We cannot generate if the model does not have a LM head