diff --git a/examples/run_generation.py b/examples/run_generation.py index 2d917660cf..2075ad8457 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -20,14 +20,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera import argparse import logging -from tqdm import trange import torch -import torch.nn.functional as F import numpy as np -from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig - from transformers import GPT2LMHeadModel, GPT2Tokenizer from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer from transformers import XLNetLMHeadModel, XLNetTokenizer @@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer from transformers import XLMWithLMHeadModel, XLMTokenizer -logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt = '%m/%d/%Y %H:%M:%S', - level = logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) logger = logging.getLogger(__name__) MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop -ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ()) - MODEL_CLASSES = { - 'gpt2': (GPT2LMHeadModel, GPT2Tokenizer), - 'ctrl': (CTRLLMHeadModel, CTRLTokenizer), - 'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), - 'xlnet': (XLNetLMHeadModel, XLNetTokenizer), - 'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer), - 'xlm': (XLMWithLMHeadModel, XLMTokenizer), + "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), + "ctrl": (CTRLLMHeadModel, CTRLTokenizer), + "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), + "xlnet": (XLNetLMHeadModel, XLNetTokenizer), + "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), + "xlm": (XLMWithLMHeadModel, XLMTokenizer), } 
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia @@ -75,81 +71,78 @@ def set_seed(args): if args.n_gpu > 0: torch.cuda.manual_seed_all(args.seed) - -def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): - """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (batch size x vocabulary size) - top_k > 0: keep only top k tokens with highest probability (top-k filtering). - top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - top_k = min(top_k, logits.size(-1)) # Safety check - if top_k > 0: - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value - - if top_p > 0.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) - logits[indices_to_remove] = filter_value - return logits +# +# Functions to prepare models' input +# -def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, - is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, 
device='cpu'): - context = torch.tensor(context, dtype=torch.long, device=device) - context = context.unsqueeze(0).repeat(num_samples, 1) - generated = context - with torch.no_grad(): - for _ in trange(length): +def prepare_ctrl_input(args, _, tokenizer, prompt_text): + if args.temperature > 0.7: + logger.info( + "CTRL typically works better with lower temperatures (and lower top_k)." + ) - inputs = {'input_ids': generated} - if is_xlnet: - # XLNet is a direct (predict same token, not next token) and bi-directional model by default - # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring) - input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1) - perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device) - perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token - target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device) - target_mapping[0, 0, -1] = 1.0 # predict last token - inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping} + encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) + if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): + logger.info( + "WARNING! 
You are not starting your generation from a control code so you won't get good results" + ) + return prompt_text, {} - if is_xlm_mlm and xlm_mask_token: - # XLM MLM models are direct models (predict same token, not next token) - # => need one additional dummy token in the input (will be masked and guessed) - input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1) - inputs = {'input_ids': input_ids} - if xlm_lang is not None: - inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1) +def prepare_xlm_input(args, model, tokenizer, prompt_text): + kwargs = {"language": None, "mask_token": None} - outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states) - next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.) + # Set the language + use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb + if hasattr(model.config, "lang2id") and use_lang_emb: + available_languages = model.config.lang2id.keys() + if args.xlm_language in available_languages: + language = args.xlm_language + else: + language = None + while language not in available_languages: + language = input( + "Using XLM. 
Select language in " + + str(list(available_languages)) + + " >>> " + ) + kwargs["language"] = tokenizer.lang2id[language] - # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) - for i in range(num_samples): - for _ in set(generated[i].tolist()): - next_token_logits[i, _] /= repetition_penalty - - filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - if temperature == 0: # greedy sampling: - next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1) - else: - next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) - generated = torch.cat((generated, next_token), dim=1) - return generated + # XLM masked-language modeling (MLM) models need masked token + is_xlm_mlm = "mlm" in args.model_name_or_path + if is_xlm_mlm: + kwargs["mask_token"] = tokenizer.mask_token_id + + return prompt_text, kwargs + + +def prepare_xlnet_input(args, _, tokenizer, prompt_text): + prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text + return prompt_text, {} + + +def prepare_transfoxl_input(args, _, tokenizer, prompt_text): + prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text + return prompt_text, {} + + +PREPROCESSING_FUNCTIONS = { + "ctrl": prepare_ctrl_input, + "xlm": prepare_xlm_input, + "xlnet": prepare_xlnet_input, + "transfo-xl": prepare_transfoxl_input, +} + + +def adjust_length_to_model(length, max_sequence_length): + if length < 0 and max_sequence_length > 0: + length = max_sequence_length + elif 0 < max_sequence_length < length: + length = max_sequence_length # No generation bigger than model size + elif length < 0: + length = MAX_LENGTH # avoid infinite loop + return length def main(): @@ -157,104 +150,81 @@ def main(): parser.add_argument("--model_type", default=None, type=str, required=True, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) parser.add_argument("--model_name_or_path", default=None, 
type=str, required=True, - help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--prompt", type=str, default="") - parser.add_argument("--padding_text", type=str, default="") - parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.") parser.add_argument("--length", type=int, default=20) - parser.add_argument("--num_samples", type=int, default=1) - parser.add_argument("--temperature", type=float, default=1.0, - help="temperature of 0 implies greedy sampling") - parser.add_argument("--repetition_penalty", type=float, default=1.0, - help="primarily useful for CTRL model; in that case, use 1.2") - parser.add_argument("--top_k", type=int, default=0) - parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--no_cuda", action='store_true', - help="Avoid using CUDA when available") - parser.add_argument('--seed', type=int, default=42, - help="random seed for initialization") - parser.add_argument('--stop_token', type=str, default=None, - help="Token at which text generation is stopped") + parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") + + parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling") + parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2") + parser.add_argument("--k", type=int, default=0) + parser.add_argument("--p", type=float, default=0.9) + + parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.") + parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") + + parser.add_argument("--seed", type=int, 
default=42, help="random seed for initialization") + parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") args = parser.parse_args() - args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.device = torch.device( + "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + ) args.n_gpu = torch.cuda.device_count() set_seed(args) - args.model_type = args.model_type.lower() - model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + # Initialize the model and tokenizer + try: + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + except KeyError as ke: + raise ke( + "the model {} you specified is not supported. You are welcome to add it and open a PR :)" + ) + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) model = model_class.from_pretrained(args.model_name_or_path) model.to(args.device) model.eval() - if args.length < 0 and model.config.max_position_embeddings > 0: - args.length = model.config.max_position_embeddings - elif 0 < model.config.max_position_embeddings < args.length: - args.length = model.config.max_position_embeddings # No generation bigger than model size - elif args.length < 0: - args.length = MAX_LENGTH # avoid infinite loop - + args.length = adjust_length_to_model( + args.length, max_sequence_length=model.config.max_position_embeddings + ) logger.info(args) - if args.model_type in ["ctrl"]: - if args.temperature > 0.7: - logger.info('CTRL typically works better with lower temperatures (and lower top_k).') - while True: - xlm_lang = None - # XLM Language usage detailed in the issues #1414 - if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \ - and model.config.use_lang_emb: - if args.xlm_lang: - language = args.xlm_lang - else: - language = None - while language not in tokenizer.lang2id.keys(): - language = 
input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ") - xlm_lang = tokenizer.lang2id[language] + prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") - # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence) - is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path - if is_xlm_mlm: - xlm_mask_token = tokenizer.mask_token_id - else: - xlm_mask_token = None + # Different models need different input formatting and/or extra arguments + requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() + model_kwargs = {} + if requires_preprocessing: + prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) + prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text) + encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0) - raw_text = args.prompt if args.prompt else input("Model prompt >>> ") - if args.model_type in ["transfo-xl", "xlnet"]: - # Models with memory likes to have a long prompt for short inputs. - raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text - context_tokens = tokenizer.encode(raw_text, add_special_tokens=False) - if args.model_type == "ctrl": - if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()): - logger.info("WARNING! 
You are not starting your generation from a control code so you won't get good results") - out = sample_sequence( - model=model, - context=context_tokens, - num_samples=args.num_samples, - length=args.length, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - repetition_penalty=args.repetition_penalty, - is_xlnet=bool(args.model_type == "xlnet"), - is_xlm_mlm=is_xlm_mlm, - xlm_mask_token=xlm_mask_token, - xlm_lang=xlm_lang, - device=args.device, - ) - out = out[:, len(context_tokens):].tolist() - for o in out: - text = tokenizer.decode(o, clean_up_tokenization_spaces=True) - text = text[: text.find(args.stop_token) if args.stop_token else None] + output_sequences = model.decode( + prompt_ids=encoded_prompt, + length=args.length, + temperature=args.temperature, + k=args.k, + p=args.p, + repetition_penalty=args.repetition_penalty, + device=args.device, + **model_kwargs, + ) - print(text) + generated_sequence = output_sequences.tolist()[ + encoded_prompt.size(1) : + ] # adapted to case where num_samples > 1 + text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) + text = text[: text.find(args.stop_token) if args.stop_token else None] + + print(text) - if args.prompt: - break return text -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/transformers/modeling_encoder_decoder.py b/transformers/modeling_encoder_decoder.py index a884abd0a2..3d8c812c2f 100644 --- a/transformers/modeling_encoder_decoder.py +++ b/transformers/modeling_encoder_decoder.py @@ -18,11 +18,14 @@ from __future__ import absolute_import, division, print_function, unicode_litera import logging import os +import warnings import torch from torch import nn +from tqdm import trange from .modeling_auto import AutoModel, AutoModelWithLMHead +from .modeling_utils import Sampler logger = logging.getLogger(__name__) @@ -117,8 +120,7 @@ class PreTrainedEncoderDecoder(nn.Module): kwargs_common = { argument: value for argument, value in 
kwargs.items() - if not argument.startswith("encoder_") - and not argument.startswith("decoder_") + if not argument.startswith("encoder_") and not argument.startswith("decoder_") } kwargs_decoder = kwargs_common.copy() kwargs_encoder = kwargs_common.copy() @@ -186,51 +188,151 @@ class PreTrainedEncoderDecoder(nn.Module): Indices of decoder input sequence tokens in the vocabulary. kwargs: (`optional`) Remaining dictionary of keyword arguments. """ - # keyword arguments come in 3 flavors: encoder-specific (prefixed by - # `encoder_`), decoder-specific (prefixed by `decoder_`) and those - # that apply to the model as whole. - # We let the specific kwargs override the common ones in case of conflict. + kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs) + + # Encode if needed (training, first prediction pass) + encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) + if encoder_hidden_states is None: + encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) + encoder_hidden_states = encoder_outputs[0] + else: + encoder_outputs = () + + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states + decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder) + + return decoder_outputs + encoder_outputs + + def decode( + self, + encoder_input_ids, + decoder_prompt_ids=None, + device=torch.device("cpu"), + length=10, + do_sample=False, + temperature=1.0, + k=9, + p=0., + repetition_penalty=1., + **kwargs + ): + """ Generic sequence generator for encoder-decoder models. + + For encoder-decoders the generation consists in: + - Performing a forward pass through the encoder once; + - Pass the encoder's hidden states to a decoding mechanism that + repeatedly calls the decoder to generate sequences. + + The method currently supports greedy decoding and sampling. See the + documentation of the `Sampler` class for more information about the + parameters related to sampling. 
+ + Params: + **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length) + The sequence to encode. + **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape (1,) + **device**: (`optional`) `torch.device` + The device on which the prompt_ids will be initialized if not provided. + **length**: (`optional`) int + The length of the sequence to be generated. + **do_sample**: (`optional`) bool + If set to `False` we use greedy decoding; otherwise sampling. + **temperature**: (`optional`) float + The value used to module the next token probabilities. + **k**: (`optional`) int + The parameter used for k-filtering. + **p**: (`optional`) float + The parameter for nucleus sampling. Must be between 0 and 1. + **repetition_penalty**: (`optional`) float + The parameter for repetition penalty. + """ + if decoder_prompt_ids is None: + decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device) + + # When the model does not have a LM head `get_output_embeddings` + # returns `None`. We use this mechanism to determine whether we + # should proceed with decoding or not. + if self.decoder.get_output_embeddings() is None: + raise AttributeError("You tried do generated sequences with a decoder that does not have a LM Head.") + + # The followings checks that the decoder is on the same device as the one + # that is specified. It only works for models that fit on one GPU. + decoder_device = next(self.decoder.parameters()).device + if decoder_device != decoder_prompt_ids.device: + warnings.warn( + "The decoder is not on the same device as the prompt. 
Expected {}, got {}.".format( + decoder_prompt_ids.device, decoder_device + ) + ) + + kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs) + with torch.no_grad(): + encoder_outputs = self.encoder(encoder_input_ids, **kwargs) + encoder_hidden_states = encoder_outputs[0] + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states + + sampler_config = { + "k": k, + "p": p, + "do_sample": do_sample, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + } + return self._greedy_decode_or_sample( + decoder_prompt_ids, length, sampler_config, **kwargs_decoder + ) + + def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder): + sampler = Sampler(**sampler_config) + with torch.no_grad(): + generated_sequence = prompt_ids + for _ in trange(length): + arguments = self.decoder._prepare_inputs_for_decoding(generated_sequence, **kwargs_decoder) + outputs = self.decoder(**arguments) + next_tokens_logits = outputs[0][:, -1, :] + next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence) + generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) + + return generated_sequence.squeeze(0) + + @staticmethod + def prepare_model_kwargs(**kwargs): + """ Prepare the encoder and decoder's keyword arguments. + + Keyword arguments come in 3 flavors: + - encoder-specific (prefixed by `encoder_`) + - decoder-specific (prefixed by `decoder_`) + - those that apply to the model as whole. + + We let the specific kwargs override the common ones in case of + conflict. 
+ """ kwargs_common = { argument: value for argument, value in kwargs.items() - if not argument.startswith("encoder_") - and not argument.startswith("decoder_") + if not argument.startswith("encoder_") and not argument.startswith("decoder_") } - kwargs_decoder = kwargs_common.copy() - kwargs_encoder = kwargs_common.copy() - kwargs_encoder.update( + decoder_kwargs = kwargs_common.copy() + encoder_kwargs = kwargs_common.copy() + encoder_kwargs.update( { argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") } ) - kwargs_decoder.update( + decoder_kwargs.update( { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } ) + decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None) - # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) - if encoder_hidden_states is None: - encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[ - 0 - ] # output the last layer hidden state - else: - encoder_outputs = () - - # Decode - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get( - "attention_mask", None - ) - decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) - - return decoder_outputs + encoder_outputs + return encoder_kwargs, decoder_kwargs class Model2Model(PreTrainedEncoderDecoder): diff --git a/transformers/modeling_transfo_xl.py b/transformers/modeling_transfo_xl.py index a6a82f0dfe..473d07f733 100644 --- a/transformers/modeling_transfo_xl.py +++ b/transformers/modeling_transfo_xl.py @@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary from .configuration_transfo_xl import TransfoXLConfig -from .modeling_transfo_xl_utilities import 
ProjectedAdaptiveLogSoftmax, sample_logits +from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) @@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): outputs = [softmax_output, None] + outputs return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions) + + def get_output_embeddings(self): + """ Double-check if you are using adaptive softmax. + """ + if self.sample_softmax > 0: + return self.out_layer + else: + return self.crit.out_layers[-1] diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 398172a88c..74038351fd 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -23,12 +23,14 @@ import json import logging import os from io import open +import warnings import six import torch from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F +from tqdm import trange from .configuration_utils import PretrainedConfig from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME @@ -87,6 +89,93 @@ class PreTrainedModel(nn.Module): def base_model(self): return getattr(self, self.base_model_prefix, self) + def decode(self, + prompt_ids=None, + device=torch.device('cpu'), + length=10, + do_sample=False, + temperature=1., + k=9, + p=0, + repetition_penalty=1, + **model_kwargs): + """ Generic sequence generator for single-stack models with a LM head. + + The method currently supports greedy decoding and sampling. See the + documentation of the `Sampler` class for more information about the + parameters related to sampling. + + Params: + **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length) + The sequence to encode. 
+ **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length) + The sequence used as a prompt for the generation. If `None` the method initializes + it as an empty `torch.LongTensor` of shape (1,) + **device**: (`optional`) `torch.device` + The device on which the prompt_ids will be initialized if not provided. + **length**: (`optional`) int + The length of the sequence to be generated. + **do_sample**: (`optional`) bool + If set to `False` we use greedy decoding; otherwise sampling. + **temperature**: (`optional`) float + The value used to module the next token probabilities. + **k**: (`optional`) int + The parameter used for k-filtering. + **p**: (`optional`) float + The parameter for nucleus sampling. Must be between 0 and 1. + **repetition_penalty**: (`optional`) float + The parameter for repetition penalty. + """ + + if prompt_ids is None: + prompt_ids = torch.tensor([[]], dtype=torch.long, device=device) + + # When the model does not have a LM head `get_output_embeddings` + # returns `None`. We use this mechanism to determine whether we + # should proceed with decoding or not. + if self.get_output_embeddings() is None: + raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.") + + # The followings checks that the model is on the same device as the one + # that is specified. It only works for models that fit on one GPU. + model_device = next(self.parameters()).device + if model_device != prompt_ids.device: + warnings.warn( + "The model is not on the same device as the prompts. 
Expected {}, got {}.".format( + prompt_ids.device, model_device + ) + ) + + sampler_config = { + "k": k, + "p": p, + "do_sample": do_sample, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + } + return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs) + + def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs): + """ Generate text using greedy decoding or by sampling tokens.""" + sampler = Sampler(**sampler_config) + generated_sequence = prompt_ids + with torch.no_grad(): + for _ in trange(length): + arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs) + outputs = self(**arguments) + next_tokens_logits = outputs[0][:, -1, :] + next_tokens = sampler.get_one_token( + next_tokens_logits, generated_sequence + ) + generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1) + + return generated_sequence.squeeze(0) + + def _prepare_inputs_for_decoding(self, input_ids, **kwargs): + arguments = {"input_ids": input_ids} + arguments.update(kwargs) + return arguments + def get_input_embeddings(self): """ Get model's input embeddings """ @@ -859,3 +948,143 @@ def prune_layer(layer, index, dim=None): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__)) + + +class Sampler(object): + r""" Sampler is used to generate sequences of ids from logit inputs. + + Greedy decoding, which consists in chosing the most probable token at each + step, is the default behaviour. Sampling with varying temperature, top_k + and nucleus filtering is also implemented. + + Attributes: + **device**: ``torch.device`` + Device on which the computations will be run. + **do_sample**: bool + Whether to sample or do greedy decoding. 
# NOTE: the functions below are the methods of the `Sampler` class followed by
# the decoding helpers added to `XLMWithLMHeadModel`. Both class headers sit
# outside this chunk, so the bodies are laid out here at column 0.
#
# Tail of the `Sampler` class docstring (parameter reference):
#
#   **k**: int between 0 and vocab_size
#       Parameter for the top-k filtering
#   **p**: float between 0 and 1
#       Parameter for the nucleus filtering
#   **temperature**: strictly positive float
#       Parameter used to modulate the distribution over ids. Low temperatures
#       put more emphasis on highly probable tokens while high temperatures
#       tend to smooth the probability distribution.
#   **repetition_penalty**: strictly positive float
#       The penalty applied to repeating ids


def __init__(self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0):
    """Store the sampling hyper-parameters.

    Args:
        do_sample: when False `get_one_token` is greedy (argmax); when True it
            samples after temperature scaling, top-k and nucleus filtering.
        k: size of the top-k set; values <= 0 disable the top-k filter.
        p: nucleus threshold; values <= 0 disable the nucleus filter.
        temperature: divisor applied to the logits before sampling.
        repetition_penalty: penalty applied to already generated ids; it is
            only applied when strictly greater than 1.

    Warns:
        UserWarning: if `p` is greater than 1 (no nucleus filtering happens).
    """
    self.k = k
    self.p = p
    self.do_sample = do_sample
    self.temperature = temperature
    self.repetition_penalty = repetition_penalty

    self.do_apply_repetition_penalty = repetition_penalty > 1

    if self.p > 1:
        warnings.warn(
            """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
            However p is a probability and its value must lie between 0 and 1. In effect, no filtering
            will be applied. If this is not the behavior you expect, change the value of p.""".format(
                self.p
            )
        )


def get_one_token(self, next_token_logits, past_sequence):
    """Pick the next token id from `next_token_logits`.

    The repetition penalty is always applied first. With `do_sample` the logits
    then go through temperature scaling, top-k and nucleus filtering before a
    multinomial draw; otherwise the argmax is returned.

    Returns:
        A tensor holding the selected id, with a trailing dimension of size 1.
    """
    logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
    if self.do_sample:
        logits = self.apply_temperature(logits)
        logits = self.apply_top_k_filter(logits)
        logits = self.apply_nucleus_filter(logits)
        return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return torch.argmax(logits, dim=-1).unsqueeze(-1)


def apply_repetition_penalty(self, logits, past_sequence):
    """ Apply a penalty to every token that already appears in the
    generated sequence.

    Positive logits are divided by the penalty and negative logits are
    multiplied by it, so that the score of a seen token always decreases
    (dividing a negative logit by a penalty > 1 would increase it instead).
    `logits` is modified in place (row 0 only) and returned. An empty
    `past_sequence` leaves the logits untouched.

    .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
       language model for controllable generation." arXiv preprint
       arXiv:1909.05858 (2019).
    """
    if self.do_apply_repetition_penalty and past_sequence.numel() > 0:
        seen_ids = torch.unique(past_sequence[0]).to(device=logits.device, dtype=torch.long)
        scores = logits[0, seen_ids]
        logits[0, seen_ids] = torch.where(
            scores < 0, scores * self.repetition_penalty, scores / self.repetition_penalty
        )
    return logits


def apply_temperature(self, logits):
    """ Shape the tokens' distribution through temperature. The lower the
    value of the temperature, the more skewed towards high probability events
    the distribution is; high temperatures flatten it.

    Raises:
        ZeroDivisionError: if the temperature is 0.

    .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
       MIT press, 2016.
    """
    # when dividing a float by 0, torch returns inf which in turns breaks the
    # multinomial with an error message that is not very helpful. It is better
    # for the user to break the execution and explain why.
    if self.temperature == 0:
        raise ZeroDivisionError(
            """You are trying to sample with a temperature equal to 0.
            If you wanted to do greedy sampling, set instead `do_sample` to False.
            Otherwise set the temperature to a value different from 0."""
        )
    return logits / self.temperature


def apply_top_k_filter(self, logits):
    """ Use the probability distribution of the tokens to determine the set
    to be sampled from. Specifically we select the set of size k such that
    the sum of its items' probabilities is maximum.

    Entries outside the top-k set are overwritten with -inf in place. If `k`
    exceeds the vocabulary size a warning is emitted and `self.k` is clamped
    to the vocabulary size. `k <= 0` disables the filter.

    .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
       story generation." arXiv preprint arXiv:1805.04833 (2018).
    """
    if self.k > 0:
        vocabulary_size = logits.size(-1)
        if self.k > vocabulary_size:
            warnings.warn(
                """You provided a value for k ({}) that is larger than the vocabulary size ({}).
                We adjusted k's value to the vocabulary size; if that was what you intended to do
                we recommend setting k to 0 instead. It this is not the behavior you expected,
                choose a value of k that is smaller than the vocabulary size.""".format(
                    self.k, vocabulary_size
                )
            )
            self.k = vocabulary_size

        # the k-th largest logit is the cut-off; everything below it is dropped
        kth_largest = torch.topk(logits, self.k)[0][..., -1, None]
        logits[logits < kth_largest] = -float("Inf")

    return logits


def apply_nucleus_filter(self, logits):
    """ Use the probability distribution of the tokens to determine the set
    to be sampled from. Specifically, choose the smallest set such that the
    sum of its items' probabilities is greater than a number p in [0,1].

    Filtered entries are overwritten with -inf in place. `p <= 0` disables
    the filter.

    .. Holtzman, Ari, et al. "The curious case of neural text
       degeneration." arXiv preprint arXiv:1904.09751 (2019).
    """
    if self.p > 0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probabilities = F.softmax(sorted_logits, dim=-1)
        cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)

        # Remove tokens with cumulative probability above the threshold,
        # but keep the first token above the threshold.
        sorted_indices_to_remove = cumulative_probabilities > self.p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = -float("Inf")

    return logits


# --- transformers/modeling_xlm.py: helpers added to `XLMWithLMHeadModel` ---


def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
    """Build the keyword arguments of one decoding step.

    `mask_token` and `language` are popped from `model_kwargs`; the remaining
    entries are forwarded unchanged in the returned dictionary.
    """
    mask_token = model_kwargs.pop("mask_token", None)
    language = model_kwargs.pop("language", None)
    input_ids = self._append_mask_token(input_ids, mask_token)
    langs = self._create_language_embeddings(input_ids, language)
    arguments = {"input_ids": input_ids, "langs": langs}
    arguments.update(model_kwargs)

    return arguments


@staticmethod
def _append_mask_token(sequence, mask_token_id):
    """ Append a [MASK] token at the end of the sequence that the MLM model
    is going to try to predict.

    One mask id is appended to every row of `sequence`, on the same device.
    The sequence is returned unchanged when `mask_token_id` is None.
    """
    if mask_token_id is None:
        return sequence

    tokens_to_append = torch.full(
        (sequence.size(0), 1), mask_token_id, dtype=torch.long, device=sequence.device
    )
    return torch.cat((sequence, tokens_to_append), dim=1)


@staticmethod
def _create_language_embeddings(sequence, language):
    """Return a tensor shaped like `sequence` filled with the `language` id.

    Despite the name this produces language ids, not embeddings. Returns None
    when `language` is None.
    """
    if language is None:
        return None
    return torch.full(sequence.shape, language, dtype=torch.long, device=sequence.device)
# --- transformers/modeling_xlnet.py: helpers added to `XLNetLMHeadModel` ---
# The class header sits outside this chunk, so the method bodies are laid out
# here at column 0.


def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
    """Build the keyword arguments of one decoding step.

    A dummy token is appended to `input_ids`; the permutation mask and target
    mapping are then derived from the extended sequence.

    NOTE(review): `model_kwargs` is accepted but not forwarded, unlike the XLM
    counterpart in this diff -- confirm this is intentional.
    """
    input_ids = self._add_dummy_token(input_ids)
    return {
        "input_ids": input_ids,
        "perm_mask": self._create_perm_mask(input_ids),
        "target_mapping": self._create_target_mapping(input_ids),
    }


@staticmethod
def _add_dummy_token(sequence):
    """Append one zero-valued placeholder id to every row of `sequence`."""
    dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long, device=sequence.device)
    return torch.cat((sequence, dummy), dim=1)


@staticmethod
def _create_perm_mask(sequence):
    """Return a (batch, seq_len, seq_len) float mask whose last column is 1."""
    batch_size, seq_len = sequence.shape[0], sequence.shape[1]
    mask = torch.zeros(
        (batch_size, seq_len, seq_len), dtype=torch.float, device=sequence.device
    )
    mask[:, :, -1] = 1.0  # Previous tokens don't see last token
    return mask


@staticmethod
def _create_target_mapping(sequence):
    """Return a (batch, 1, seq_len) float mapping selecting the last position.

    The last position is flagged for every row of the batch, not only the
    first one, so batched prompts all get a prediction target.
    """
    target_mapping = torch.zeros(
        (sequence.shape[0], 1, sequence.shape[1]), dtype=torch.float, device=sequence.device
    )
    target_mapping[:, 0, -1] = 1.0  # predict last token
    return target_mapping
# coding=utf-8
"""Tests for the `Sampler` filters and for the `decode` entry point of the models."""
import sys
import unittest

import numpy as np
import pytest

from transformers import is_torch_available

if is_torch_available():
    import torch

    from transformers import (
        BertConfig,
        BertModel,
        GPT2Config,
        GPT2LMHeadModel,
        OpenAIGPTConfig,
        OpenAIGPTLMHeadModel,
        TransfoXLConfig,
        TransfoXLLMHeadModel,
        XLMConfig,
        XLMWithLMHeadModel,
        XLNetConfig,
        XLNetLMHeadModel,
        Model2Model,
    )
    from transformers.modeling_utils import Sampler
else:
    pytestmark = pytest.mark.skip("Require Torch")


def _build_sampler(**overrides):
    """Create a sampling `Sampler` with neutral settings, except for `overrides`."""
    config = {
        "do_sample": True,
        "temperature": 1.0,
        "k": 0,
        "p": 0,
        "repetition_penalty": 1.0,
    }
    config.update(overrides)
    return Sampler(**config)


def _case(parameter, value, logits, expected):
    """Describe one filter scenario: the parameter under test, the input and the expected output."""
    return {
        parameter: value,
        "logits": torch.tensor(logits),
        "expected": torch.tensor(expected),
    }


class SamplerTest(unittest.TestCase):
    def test_nucleus_sampling(self):
        # filtered-out entries are expected to be set to negative infinity
        neg_inf = -float("Inf")
        scenarios = (
            _case("p", 0, [0.3, 0.1, 0.2], [0.3, 0.1, 0.2]),
            _case("p", 0.01, [0.3, 0.1, 0.2], [0.3, neg_inf, neg_inf]),
            _case("p", 1, [0.3, 0.1, 0.2], [0.3, 0.1, 0.2]),
            _case("p", 0.2, [0.7, 0.1, 0.2], [0.7, neg_inf, neg_inf]),
            _case("p", 0.71, [0.7, 0.1, 0.2], [0.7, neg_inf, 0.2]),
            _case("p", 0.71, [0.1, 0.7, 0.2], [neg_inf, 0.7, 0.2]),
            _case("p", 0.71, [0.7, 0.2, 0.1], [0.7, 0.2, neg_inf]),
            _case("p", 0.91, [0.7, 0.1, 0.2], [0.7, 0.1, 0.2]),
        )
        for scenario in scenarios:
            sampler = _build_sampler(p=scenario["p"])
            filtered = sampler.apply_nucleus_filter(scenario["logits"])
            np.testing.assert_array_equal(scenario["expected"].numpy(), filtered.numpy())

    def test_top_k_filter(self):
        # filtered-out entries are expected to be set to negative infinity
        neg_inf = -float("Inf")
        scenarios = (
            _case("k", 0, [0.7, 0.1, 0.2], [0.7, 0.1, 0.2]),
            _case("k", 1, [0.7, 0.1, 0.2], [0.7, neg_inf, neg_inf]),
            _case("k", 2, [0.7, 0.1, 0.2], [0.7, neg_inf, 0.2]),
            _case("k", 3, [0.7, 0.1, 0.2], [0.7, 0.1, 0.2]),
        )
        for scenario in scenarios:
            sampler = _build_sampler(k=scenario["k"])
            filtered = sampler.apply_top_k_filter(scenario["logits"])
            np.testing.assert_array_equal(scenario["expected"].numpy(), filtered.numpy())

    @pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2")
    def test_wrong_k_value(self):
        # k larger than the vocabulary must trigger a warning
        case = {"k": 10, "vocab_size": 5}
        sampler = _build_sampler(k=case["k"])
        next_token_logits = torch.rand(case["vocab_size"]).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertWarns(UserWarning):
            _ = sampler.get_one_token(next_token_logits, past_sequence)

    def test_zero_temperature(self):
        # sampling with a temperature of 0 must fail with ZeroDivisionError
        temperature = 0
        sampler = _build_sampler(temperature=temperature)
        next_token_logits = torch.rand(10).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertRaises(ZeroDivisionError):
            _ = sampler.get_one_token(next_token_logits, past_sequence)


class SamplerSingleStackTest(unittest.TestCase):
    def test_raises_exception_when_no_LM_head(self):
        models = [BertModel(BertConfig())]
        for model in models:
            with self.assertRaises(AttributeError):
                model.decode()

    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        models = {
            "XLNet": XLNetLMHeadModel(XLNetConfig()),
            "XLM": XLMWithLMHeadModel(XLMConfig()),
            "TransfoXL": TransfoXLLMHeadModel(TransfoXLConfig()),
            "GPT2": GPT2LMHeadModel(GPT2Config()),
            "GPT": OpenAIGPTLMHeadModel(OpenAIGPTConfig()),
        }
        # extra keyword arguments passed to `decode`, per model
        kwargs = {
            "XLNet": {},
            "XLM": {"mask_token": 0},
            "TransfoXL": {},
            "GPT2": {},
            "GPT": {},
        }
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8

        for name, model in models.items():
            output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs[name])
            self.assertEqual(len(output), expected_length)


class SamplerEncoderDecoderTest(unittest.TestCase):
    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        model = Model2Model.from_pretrained("bert-base-uncased")

        encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8

        output = model.decode(
            encoder_input_ids,
            decoder_prompt_ids=prompt,
            k=2,
            p=0.5,
            repetition_penalty=2,
            length=generated_length,
        )
        self.assertEqual(len(output), expected_length)


if __name__ == "__main__":
    unittest.main()