add greedy decoding and sampling
This commit is contained in:
@@ -20,14 +20,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from tqdm import trange
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
|
||||
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
||||
from transformers import XLNetLMHeadModel, XLNetTokenizer
|
||||
@@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer
|
||||
from transformers import XLMWithLMHeadModel, XLMTokenizer
|
||||
|
||||
|
||||
# Configure the root logger once (timestamp - level - logger name - message).
# A second `basicConfig` call would be a no-op because the root logger already
# has a handler after the first one, so only a single call is kept.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
|
||||
|
||||
# Maps the `--model_type` command-line value to the (LM-head model class,
# tokenizer class) pair used to load the checkpoint. Each key is listed once;
# duplicated keys in a dict literal silently overwrite each other.
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}
|
||||
|
||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||
@@ -75,81 +71,78 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The vocabulary is taken to be the last dimension, so both a single
    distribution (vocabulary size,) and a batch (batch size x vocabulary size)
    are accepted. The input tensor is not modified; a filtered tensor is
    returned (the input itself is returned when no filter is active).

    Args:
        logits: logits distribution, vocabulary on the last dimension.
        top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value: value written over the logits of removed tokens.

    Returns:
        A tensor shaped like `logits` where removed tokens hold `filter_value`.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit lower than the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the mask back to the original (unsorted) vocabulary order.
        # dim=-1 (rather than dim=1) keeps this valid for un-batched logits.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
|
||||
#
|
||||
# Functions to prepare models' input
|
||||
#
|
||||
|
||||
|
||||
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
|
||||
is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
|
||||
context = torch.tensor(context, dtype=torch.long, device=device)
|
||||
context = context.unsqueeze(0).repeat(num_samples, 1)
|
||||
generated = context
|
||||
with torch.no_grad():
|
||||
for _ in trange(length):
|
||||
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
|
||||
if args.temperature > 0.7:
|
||||
logger.info(
|
||||
"CTRL typically works better with lower temperatures (and lower top_k)."
|
||||
)
|
||||
|
||||
inputs = {'input_ids': generated}
|
||||
if is_xlnet:
|
||||
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
|
||||
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
|
||||
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
|
||||
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
|
||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
|
||||
target_mapping[0, 0, -1] = 1.0 # predict last token
|
||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
|
||||
logger.info(
|
||||
"WARNING! You are not starting your generation from a control code so you won't get good results"
|
||||
)
|
||||
return prompt_text, {}
|
||||
|
||||
if is_xlm_mlm and xlm_mask_token:
|
||||
# XLM MLM models are direct models (predict same token, not next token)
|
||||
# => need one additional dummy token in the input (will be masked and guessed)
|
||||
input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
|
||||
inputs = {'input_ids': input_ids}
|
||||
|
||||
if xlm_lang is not None:
|
||||
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
|
||||
def prepare_xlm_input(args, model, tokenizer, prompt_text):
|
||||
kwargs = {"language": None, "mask_token": None}
|
||||
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
|
||||
next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
|
||||
|
||||
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
|
||||
for i in range(num_samples):
|
||||
for _ in set(generated[i].tolist()):
|
||||
next_token_logits[i, _] /= repetition_penalty
|
||||
|
||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||
if temperature == 0: # greedy sampling:
|
||||
next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
|
||||
# Set the language
|
||||
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
|
||||
if hasattr(model.config, "lang2id") and use_lang_emb:
|
||||
available_languages = model.config.lang2id.keys()
|
||||
if args.xlm_language in available_languages:
|
||||
language = args.xlm_language
|
||||
else:
|
||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||
generated = torch.cat((generated, next_token), dim=1)
|
||||
return generated
|
||||
language = None
|
||||
while language not in available_languages:
|
||||
language = input(
|
||||
"Using XLM. Select language in "
|
||||
+ str(list(available_languages))
|
||||
+ " >>> "
|
||||
)
|
||||
kwargs["language"] = tokenizer.lang2id[language]
|
||||
|
||||
# XLM masked-language modeling (MLM) models need masked token
|
||||
is_xlm_mlm = "mlm" in args.model_name_or_path
|
||||
if is_xlm_mlm:
|
||||
kwargs["mask_token"] = tokenizer.mask_token_id
|
||||
|
||||
return prompt_text, kwargs
|
||||
|
||||
|
||||
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    """Prepend padding text to the prompt for XLNet.

    Uses `args.padding_text` when it is non-empty and the module-level
    `PADDING_TEXT` otherwise. Returns the padded prompt together with an
    empty dict of extra model keyword arguments.
    """
    prefix = args.padding_text or PADDING_TEXT
    return prefix + prompt_text, {}
|
||||
|
||||
|
||||
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    """Prepend padding text to the prompt for Transformer-XL.

    Uses `args.padding_text` when it is non-empty and the module-level
    `PADDING_TEXT` otherwise. Returns the padded prompt together with an
    empty dict of extra model keyword arguments.
    """
    prefix = args.padding_text or PADDING_TEXT
    return prefix + prompt_text, {}
|
||||
|
||||
|
||||
# Model types whose prompt needs preparation before encoding. Each function is
# called as f(args, model, tokenizer, prompt_text) and returns a tuple
# (prompt_text, extra_model_kwargs). Model types absent from this table are
# used with the raw prompt.
PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}
|
||||
|
||||
|
||||
def adjust_length_to_model(length, max_sequence_length):
    """Clamp the requested generation length to what the model supports.

    - When the model reports a positive maximum sequence length, a negative
      request or a request above that maximum becomes the maximum.
    - When the model reports no positive maximum, a negative request falls
      back to `MAX_LENGTH` so generation cannot loop forever.
    - Any other request is returned unchanged.
    """
    model_is_bounded = max_sequence_length > 0
    if model_is_bounded and (length < 0 or length > max_sequence_length):
        return max_sequence_length  # No generation bigger than model size
    if length < 0:
        return MAX_LENGTH  # avoid infinite loop
    return length
|
||||
|
||||
|
||||
def main():
|
||||
@@ -157,104 +150,81 @@ def main():
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--padding_text", type=str, default="")
|
||||
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
|
||||
parser.add_argument("--length", type=int, default=20)
|
||||
parser.add_argument("--num_samples", type=int, default=1)
|
||||
parser.add_argument("--temperature", type=float, default=1.0,
|
||||
help="temperature of 0 implies greedy sampling")
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1.0,
|
||||
help="primarily useful for CTRL model; in that case, use 1.2")
|
||||
parser.add_argument("--top_k", type=int, default=0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
help="Avoid using CUDA when available")
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help="random seed for initialization")
|
||||
parser.add_argument('--stop_token', type=str, default=None,
|
||||
help="Token at which text generation is stopped")
|
||||
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling")
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
|
||||
parser.add_argument("--k", type=int, default=0)
|
||||
parser.add_argument("--p", type=float, default=0.9)
|
||||
|
||||
parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
|
||||
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
||||
)
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
|
||||
set_seed(args)
|
||||
|
||||
# Initialize the model and tokenizer
|
||||
try:
|
||||
args.model_type = args.model_type.lower()
|
||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
except KeyError as ke:
|
||||
raise ke(
|
||||
"the model {} you specified is not supported. You are welcome to add it and open a PR :)"
|
||||
)
|
||||
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
model = model_class.from_pretrained(args.model_name_or_path)
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
if args.length < 0 and model.config.max_position_embeddings > 0:
|
||||
args.length = model.config.max_position_embeddings
|
||||
elif 0 < model.config.max_position_embeddings < args.length:
|
||||
args.length = model.config.max_position_embeddings # No generation bigger than model size
|
||||
elif args.length < 0:
|
||||
args.length = MAX_LENGTH # avoid infinite loop
|
||||
|
||||
args.length = adjust_length_to_model(
|
||||
args.length, max_sequence_length=model.config.max_position_embeddings
|
||||
)
|
||||
logger.info(args)
|
||||
if args.model_type in ["ctrl"]:
|
||||
if args.temperature > 0.7:
|
||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||
|
||||
while True:
|
||||
xlm_lang = None
|
||||
# XLM Language usage detailed in the issues #1414
|
||||
if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
|
||||
and model.config.use_lang_emb:
|
||||
if args.xlm_lang:
|
||||
language = args.xlm_lang
|
||||
else:
|
||||
language = None
|
||||
while language not in tokenizer.lang2id.keys():
|
||||
language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
|
||||
xlm_lang = tokenizer.lang2id[language]
|
||||
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||
|
||||
# XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
|
||||
is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
|
||||
if is_xlm_mlm:
|
||||
xlm_mask_token = tokenizer.mask_token_id
|
||||
else:
|
||||
xlm_mask_token = None
|
||||
# Different models need different input formatting and/or extra arguments
|
||||
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
|
||||
model_kwargs = {}
|
||||
if requires_preprocessing:
|
||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
|
||||
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
|
||||
|
||||
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||
if args.model_type in ["transfo-xl", "xlnet"]:
|
||||
# Models with memory likes to have a long prompt for short inputs.
|
||||
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
|
||||
context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
|
||||
if args.model_type == "ctrl":
|
||||
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
|
||||
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
||||
out = sample_sequence(
|
||||
model=model,
|
||||
context=context_tokens,
|
||||
num_samples=args.num_samples,
|
||||
output_sequences = model.decode(
|
||||
prompt_ids=encoded_prompt,
|
||||
length=args.length,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
k=args.k,
|
||||
p=args.p,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
is_xlnet=bool(args.model_type == "xlnet"),
|
||||
is_xlm_mlm=is_xlm_mlm,
|
||||
xlm_mask_token=xlm_mask_token,
|
||||
xlm_lang=xlm_lang,
|
||||
device=args.device,
|
||||
**model_kwargs,
|
||||
)
|
||||
out = out[:, len(context_tokens):].tolist()
|
||||
for o in out:
|
||||
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
|
||||
|
||||
generated_sequence = output_sequences.tolist()[
|
||||
encoded_prompt.size(1) :
|
||||
] # adapted to case where num_samples > 1
|
||||
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
||||
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||
|
||||
print(text)
|
||||
|
||||
if args.prompt:
|
||||
break
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -18,11 +18,14 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import trange
|
||||
|
||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
from .modeling_utils import Sampler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -117,8 +120,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("encoder_")
|
||||
and not argument.startswith("decoder_")
|
||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = kwargs_common.copy()
|
||||
kwargs_encoder = kwargs_common.copy()
|
||||
@@ -186,51 +188,151 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments.
|
||||
"""
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||
decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
def decode(
    self,
    encoder_input_ids,
    decoder_prompt_ids=None,
    device=torch.device("cpu"),
    length=10,
    do_sample=False,
    temperature=1.0,
    k=9,
    p=0.,
    repetition_penalty=1.,
    **kwargs
):
    """ Generic sequence generator for encoder-decoder models.

    For encoder-decoders the generation consists in:
    - Performing a forward pass through the encoder once;
    - Pass the encoder's hidden states to a decoding mechanism that
      repeatedly calls the decoder to generate sequences.

    The method currently supports greedy decoding and sampling. See the
    documentation of the `Sampler` class for more information about the
    parameters related to sampling.

    Params:
        **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length)
            The sequence to encode.
        **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
            The sequence used as a prompt for the generation. If `None` the method initializes
            it as an empty `torch.LongTensor` of shape (1, 0)
        **device**: (`optional`) `torch.device`
            The device on which the prompt_ids will be initialized if not provided.
        **length**: (`optional`) int
            The length of the sequence to be generated.
        **do_sample**: (`optional`) bool
            If set to `False` we use greedy decoding; otherwise sampling.
        **temperature**: (`optional`) float
            The value used to module the next token probabilities.
        **k**: (`optional`) int
            The parameter used for k-filtering.
        **p**: (`optional`) float
            The parameter for nucleus sampling. Must be between 0 and 1.
        **repetition_penalty**: (`optional`) float
            The parameter for repetition penalty.
        **kwargs**: (`optional`) keyword arguments split by `prepare_model_kwargs`
            into encoder and decoder arguments (`encoder_` / `decoder_` prefixes).

    Raises:
        AttributeError: if the decoder's `get_output_embeddings()` returns `None`.
    """
    if decoder_prompt_ids is None:
        decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)

    # When the model does not have a LM head `get_output_embeddings`
    # returns `None`. We use this mechanism to determine whether we
    # should proceed with decoding or not.
    if self.decoder.get_output_embeddings() is None:
        raise AttributeError("You tried do generated sequences with a decoder that does not have a LM Head.")

    # The following checks that the decoder is on the same device as the
    # prompt. It only looks at the first parameter, so it only works for
    # models that fit on one device.
    decoder_device = next(self.decoder.parameters()).device
    if decoder_device != decoder_prompt_ids.device:
        warnings.warn(
            "The decoder is not on the same device as the prompt. Expected {}, got {}.".format(
                decoder_prompt_ids.device, decoder_device
            )
        )

    kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
    with torch.no_grad():
        # Use the encoder-specific arguments: the raw `kwargs` still carry the
        # `encoder_` / `decoder_` prefixes and decoder-only entries.
        encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
        encoder_hidden_states = encoder_outputs[0]
        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states

    sampler_config = {
        "k": k,
        "p": p,
        "do_sample": do_sample,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
    }
    return self._greedy_decode_or_sample(
        decoder_prompt_ids, length, sampler_config, **kwargs_decoder
    )
|
||||
|
||||
def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder):
    """ Extend `prompt_ids` by `length` tokens, one decoder call per token.

    At each step the decoder is run on the sequence generated so far, the
    logits of the last position are handed to a `Sampler` built from
    `sampler_config`, and the token it returns is appended along dim 1.
    Gradients are disabled. The result has its leading dimension squeezed.
    """
    token_sampler = Sampler(**sampler_config)
    sequence = prompt_ids
    with torch.no_grad():
        for _step in trange(length):
            decoder_inputs = self.decoder._prepare_inputs_for_decoding(sequence, **kwargs_decoder)
            # outputs[0] holds the logits; keep only the last position.
            last_position_logits = self.decoder(**decoder_inputs)[0][:, -1, :]
            new_tokens = token_sampler.get_one_token(last_position_logits, sequence)
            sequence = torch.cat((sequence, new_tokens), dim=1)

    return sequence.squeeze(0)
|
||||
|
||||
@staticmethod
|
||||
def prepare_model_kwargs(**kwargs):
|
||||
""" Prepare the encoder and decoder's keyword arguments.
|
||||
|
||||
Keyword arguments come in 3 flavors:
|
||||
- encoder-specific (prefixed by `encoder_`)
|
||||
- decoder-specific (prefixed by `decoder_`)
|
||||
- those that apply to the model as whole.
|
||||
|
||||
We let the specific kwargs override the common ones in case of
|
||||
conflict.
|
||||
"""
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("encoder_")
|
||||
and not argument.startswith("decoder_")
|
||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = kwargs_common.copy()
|
||||
kwargs_encoder = kwargs_common.copy()
|
||||
kwargs_encoder.update(
|
||||
decoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs.update(
|
||||
{
|
||||
argument[len("encoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
)
|
||||
kwargs_decoder.update(
|
||||
decoder_kwargs.update(
|
||||
{
|
||||
argument[len("decoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
)
|
||||
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[
|
||||
0
|
||||
] # output the last layer hidden state
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
# Decode
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get(
|
||||
"attention_mask", None
|
||||
)
|
||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
return encoder_kwargs, decoder_kwargs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedEncoderDecoder):
|
||||
|
||||
@@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
|
||||
from .configuration_transfo_xl import TransfoXLConfig
|
||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
outputs = [softmax_output, None] + outputs
|
||||
|
||||
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
|
||||
|
||||
def get_output_embeddings(self):
    """ Return the layer used as output embeddings.

    When `self.sample_softmax > 0` this is `self.out_layer`; otherwise it is
    the last entry of `self.crit.out_layers`. Double-check the result if you
    are using adaptive softmax.
    """
    uses_sampled_softmax = self.sample_softmax > 0
    return self.out_layer if uses_sampled_softmax else self.crit.out_layers[-1]
|
||||
|
||||
@@ -23,12 +23,14 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from io import open
|
||||
import warnings
|
||||
|
||||
import six
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
from tqdm import trange
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||
@@ -87,6 +89,93 @@ class PreTrainedModel(nn.Module):
|
||||
def base_model(self):
|
||||
return getattr(self, self.base_model_prefix, self)
|
||||
|
||||
def decode(self,
           prompt_ids=None,
           device=torch.device('cpu'),
           length=10,
           do_sample=False,
           temperature=1.,
           k=9,
           p=0,
           repetition_penalty=1,
           **model_kwargs):
    """ Generic sequence generator for single-stack models with a LM head.

    The method currently supports greedy decoding and sampling. See the
    documentation of the `Sampler` class for more information about the
    parameters related to sampling.

    Params:
        **prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
            The sequence used as a prompt for the generation. If `None` the method initializes
            it as an empty `torch.LongTensor` of shape (1, 0)
        **device**: (`optional`) `torch.device`
            The device on which the prompt_ids will be initialized if not provided.
        **length**: (`optional`) int
            The length of the sequence to be generated.
        **do_sample**: (`optional`) bool
            If set to `False` we use greedy decoding; otherwise sampling.
        **temperature**: (`optional`) float
            The value used to module the next token probabilities.
        **k**: (`optional`) int
            The parameter used for k-filtering.
        **p**: (`optional`) float
            The parameter for nucleus sampling. Must be between 0 and 1.
        **repetition_penalty**: (`optional`) float
            The parameter for repetition penalty.
        **model_kwargs**: (`optional`) extra keyword arguments forwarded to
            `_greedy_decode_or_sample`.

    Raises:
        AttributeError: if `get_output_embeddings()` returns `None`.
    """

    if prompt_ids is None:
        prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)

    # When the model does not have a LM head `get_output_embeddings`
    # returns `None`. We use this mechanism to determine whether we
    # should proceed with decoding or not.
    if self.get_output_embeddings() is None:
        raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")

    # The following checks that the model is on the same device as the
    # prompt. It only looks at the first parameter, so it only works for
    # models that fit on one device.
    model_device = next(self.parameters()).device
    if model_device != prompt_ids.device:
        warnings.warn(
            "The model is not on the same device as the prompts. Expected {}, got {}.".format(
                prompt_ids.device, model_device
            )
        )

    sampler_config = {
        "k": k,
        "p": p,
        "do_sample": do_sample,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
    }
    return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs)
|
||||
|
||||
def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs):
    """ Generate text using greedy decoding or by sampling tokens.

    Starting from `prompt_ids`, the model is called `length` times on the
    sequence generated so far; the logits of the last position are handed to
    a `Sampler` built from `sampler_config` and the token it returns is
    appended along dim 1. Gradients are disabled. The result has its leading
    dimension squeezed.
    """
    token_sampler = Sampler(**sampler_config)
    sequence = prompt_ids
    with torch.no_grad():
        for _step in trange(length):
            model_inputs = self._prepare_inputs_for_decoding(sequence, **model_kwargs)
            # outputs[0] holds the logits; keep only the last position.
            last_position_logits = self(**model_inputs)[0][:, -1, :]
            new_tokens = token_sampler.get_one_token(last_position_logits, sequence)
            sequence = torch.cat((sequence, new_tokens), dim=1)

    return sequence.squeeze(0)
|
||||
|
||||
def _prepare_inputs_for_decoding(self, input_ids, **kwargs):
|
||||
arguments = {"input_ids": input_ids}
|
||||
arguments.update(kwargs)
|
||||
return arguments
|
||||
|
||||
def get_input_embeddings(self):
|
||||
""" Get model's input embeddings
|
||||
"""
|
||||
@@ -859,3 +948,143 @@ def prune_layer(layer, index, dim=None):
|
||||
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
|
||||
else:
|
||||
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
|
||||
|
||||
|
||||
class Sampler(object):
|
||||
r""" Sampler is used to generate sequences of ids from logit inputs.
|
||||
|
||||
Greedy decoding, which consists in choosing the most probable token at each
|
||||
step, is the default behaviour. Sampling with varying temperature, top_k
|
||||
and nucleus filtering is also implemented.
|
||||
|
||||
Attributes:
|
||||
**device**: ``torch.device``
|
||||
Device on which the computations will be run.
|
||||
**do_sample**: bool
|
||||
Whether to sample or do greedy decoding.
|
||||
**k**: int between 0 and vocab_size
|
||||
Parameter for the top-k filtering
|
||||
**p**: float between 0 and 1
|
||||
Parameter for the nucleus filtering
|
||||
**temperature**: strictly positive float
|
||||
Parameter used to modulate the distribution over ids. Low temperatures
|
||||
put more emphasis on highly probably token while high temperatures tend
|
||||
to smooth the probability distribution.
|
||||
**repetition_penalty**: strictly positive float
|
||||
The penalty applied to repeating ids
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0
|
||||
):
|
||||
self.k = k
|
||||
self.p = p
|
||||
self.do_sample = do_sample
|
||||
self.temperature = temperature
|
||||
self.repetition_penalty = repetition_penalty
|
||||
|
||||
self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False
|
||||
|
||||
if self.p > 1:
|
||||
warnings.warn(
|
||||
"""You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
|
||||
However p is a probability and its value must lie between 0 and 1. In effect, no filtering
|
||||
will be applied. If this is not the behavior you expect, change the value of p.""".format(
|
||||
self.p
|
||||
)
|
||||
)
|
||||
|
||||
def get_one_token(self, next_token_logits, past_sequence):
|
||||
logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
|
||||
if self.do_sample:
|
||||
logits = self.apply_temperature(logits)
|
||||
logits = self.apply_top_k_filter(logits)
|
||||
logits = self.apply_nucleus_filter(logits)
|
||||
return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
||||
return torch.argmax(logits, dim=-1).unsqueeze(-1)
|
||||
|
||||
def apply_repetition_penalty(self, logits, past_sequence):
|
||||
""" Apply a penalty to tokens that appear more than once in the
|
||||
generated sequence.
|
||||
|
||||
.. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
|
||||
language model for controllable generation." arXiv preprint
|
||||
arXiv:1909.05858 (2019).
|
||||
"""
|
||||
if self.do_apply_repetition_penalty:
|
||||
generated_token_idx = set(past_sequence[0].tolist())
|
||||
for token_idx in generated_token_idx:
|
||||
logits[0, token_idx] /= self.repetition_penalty
|
||||
return logits
|
||||
|
||||
def apply_temperature(self, logits):
|
||||
""" Shape the tokens' distribution through temperature. The higher the value
|
||||
of the temperature, the more skewed towards high probability events the
|
||||
distribution is.
|
||||
|
||||
.. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
|
||||
MIT press, 2016.
|
||||
"""
|
||||
# when dividing a float by 0, torch returns inf which in turns breaks the
|
||||
# multinomial with an error message that is not very helpful. It is better
|
||||
# for the user to break the execution and explain why.
|
||||
if self.temperature == 0:
|
||||
raise ZeroDivisionError(
|
||||
"""You are trying to sample with a temperature equal to 0.
|
||||
If you wanted to do greedy sampling, set instead `do_sample` to False.
|
||||
Otherwise set the temperature to a value different from 0."""
|
||||
)
|
||||
return logits / self.temperature
|
||||
|
||||
def apply_top_k_filter(self, logits):
|
||||
""" Use the probability distribution of the tokens to determine the set
|
||||
to be sampled from. Specifically we select the set of size k such that
|
||||
the sum of its items' probabilities is maximum.
|
||||
|
||||
.. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
|
||||
story generation." arXiv preprint arXiv:1805.04833 (2018).
|
||||
"""
|
||||
if self.k > 0:
|
||||
vocabulary_size = logits.size(-1)
|
||||
if self.k > vocabulary_size:
|
||||
warnings.warn(
|
||||
"""You provided a value for k ({}) that is larger than the vocabulary size ({}).
|
||||
We adjusted k's value to the vocabulary size; if that was what you intended to do
|
||||
we recommend setting k to 0 instead. It this is not the behavior you expected,
|
||||
choose a value of k that is smaller than the vocabulary size.""".format(
|
||||
self.k, vocabulary_size
|
||||
)
|
||||
)
|
||||
self.k = vocabulary_size
|
||||
|
||||
indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = -float("Inf")
|
||||
|
||||
return logits
|
||||
|
||||
def apply_nucleus_filter(self, logits):
|
||||
""" Use the probability distribution of the tokens to determine the set
|
||||
to be sampled from. Specifically, choose the smallest set such that the
|
||||
sum of its items' probabilities is greater than a number p in [0,1].
|
||||
|
||||
.. Holtzman, Ari, et al. "The curious case of neural text
|
||||
degeneration." arXiv preprint arXiv:1904.09751 (2019).
|
||||
"""
|
||||
if self.p > 0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
sorted_probabilities = F.softmax(sorted_logits, dim=-1)
|
||||
cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold,
|
||||
# but keep the first token above the threshold.
|
||||
sorted_indices_to_remove = cumulative_probabilities > self.p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits[indices_to_remove] = -float("Inf")
|
||||
|
||||
return logits
|
||||
|
||||
@@ -657,6 +657,33 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
|
||||
return outputs
|
||||
|
||||
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
    """ Build the forward-pass arguments for one decoding step.

    The ``mask_token`` and ``language`` entries are taken out of the keyword
    arguments (both default to None) and handed to ``_append_mask_token`` and
    ``_create_language_embeddings``. The remaining keyword arguments are
    merged into the result after ``input_ids`` and ``langs``.
    """
    mask_id = model_kwargs.pop("mask_token", None)
    language_id = model_kwargs.pop("language", None)

    extended_ids = self._append_mask_token(input_ids, mask_id)
    prepared = {
        "input_ids": extended_ids,
        "langs": self._create_language_embeddings(extended_ids, language_id),
    }
    prepared.update(model_kwargs)
    return prepared
|
||||
|
||||
@staticmethod
def _append_mask_token(sequence, mask_token_id):
    """ Append a [MASK] token at the end of the sequence that the MLM model
    is going to try to predict.

    One mask token is appended to every row of ``sequence``, and the appended
    column is created on the same device as ``sequence``. When
    ``mask_token_id`` is None the sequence is returned untouched.
    """
    if mask_token_id is None:
        return sequence

    # Size the column from the batch dimension and reuse the sequence's
    # device so the concatenation works for any batch size and placement.
    tokens_to_append = torch.full(
        (sequence.size(0), 1),
        mask_token_id,
        dtype=torch.long,
        device=sequence.device,
    )
    return torch.cat((sequence, tokens_to_append), dim=1)
|
||||
|
||||
@staticmethod
def _create_language_embeddings(sequence, language):
    """ Build a tensor filled with the ``language`` id, with one entry per
    position of ``sequence``.

    The result has the same (batch, length) shape as ``sequence`` and lives
    on the same device. Returns None when ``language`` is None.
    """
    if language is None:
        return None

    batch_size, sequence_length = sequence.shape[0], sequence.shape[1]
    row = torch.tensor([language] * sequence_length, device=sequence.device)
    # Repeat the row for every element of the batch; for a batch of one this
    # yields the (1, length) tensor.
    return row.repeat(batch_size, 1)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
|
||||
@@ -972,6 +972,40 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
|
||||
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
    """ Build the forward-pass arguments for one decoding step.

    A dummy token is appended to ``input_ids`` and the permutation mask and
    target mapping are derived from the extended sequence. Note that
    ``model_kwargs`` is accepted but not forwarded to the result.
    """
    extended_ids = self._add_dummy_token(input_ids)
    return {
        "input_ids": extended_ids,
        "perm_mask": self._create_perm_mask(extended_ids),
        "target_mapping": self._create_target_mapping(extended_ids),
    }
|
||||
|
||||
@staticmethod
def _add_dummy_token(sequence):
    """ Append a column of zeros (one dummy token per row) to ``sequence``.

    The dummy column is created on the same device as ``sequence`` so that
    the concatenation does not fail when the sequence is not on the default
    device.
    """
    dummy = torch.zeros(
        (sequence.size(0), 1), dtype=torch.long, device=sequence.device
    )
    return torch.cat((sequence, dummy), dim=1)
|
||||
|
||||
@staticmethod
def _create_perm_mask(sequence):
    """ Build a float mask of shape (batch, length, length) whose last column
    is set to 1 and every other entry to 0.

    The mask is created on the same device as ``sequence``.
    """
    batch_size, sequence_length = sequence.shape[0], sequence.shape[1]
    mask = torch.zeros(
        (batch_size, sequence_length, sequence_length),
        dtype=torch.float,
        device=sequence.device,
    )
    mask[:, :, -1] = 1.0  # Previous tokens don't see last token
    return mask
|
||||
|
||||
@staticmethod
def _create_target_mapping(sequence):
    """ Build a float tensor of shape (batch, 1, length) that selects the last
    position of every row as the single prediction target.

    The tensor is created on the same device as ``sequence``.
    """
    target_mapping = torch.zeros(
        (sequence.shape[0], 1, sequence.shape[1]),
        dtype=torch.float,
        device=sequence.device,
    )
    # Mark the last position for every element of the batch, not only the
    # first one, so that each row has a prediction target.
    target_mapping[:, 0, -1] = 1.0  # predict last token
    return target_mapping
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
|
||||
the pooled output) e.g. for GLUE tasks. """,
|
||||
|
||||
213
transformers/tests/sampling_test.py
Normal file
213
transformers/tests/sampling_test.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# coding=utf-8
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertModel,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
OpenAIGPTConfig,
|
||||
OpenAIGPTLMHeadModel,
|
||||
TransfoXLConfig,
|
||||
TransfoXLLMHeadModel,
|
||||
XLMConfig,
|
||||
XLMWithLMHeadModel,
|
||||
XLNetConfig,
|
||||
XLNetLMHeadModel,
|
||||
Model2Model,
|
||||
)
|
||||
from transformers.modeling_utils import Sampler
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
|
||||
class SamplerTest(unittest.TestCase):
    """ Checks the logit filters of ``Sampler`` and its reaction to invalid
    ``k`` and ``temperature`` values.
    """

    @staticmethod
    def _make_sampler(k=0, p=0, temperature=1.0):
        # Every test samples (do_sample=True) with a neutral repetition penalty.
        return Sampler(
            do_sample=True,
            temperature=temperature,
            k=k,
            p=p,
            repetition_penalty=1.0,
        )

    def test_nucleus_sampling(self):
        neg_inf = -float("Inf")
        # Each entry is (p, input logits, expected filtered logits).
        scenarios = (
            (0, [0.3, 0.1, 0.2], [0.3, 0.1, 0.2]),
            (0.01, [0.3, 0.1, 0.2], [0.3, neg_inf, neg_inf]),
            (1, [0.3, 0.1, 0.2], [0.3, 0.1, 0.2]),
            (0.2, [0.7, 0.1, 0.2], [0.7, neg_inf, neg_inf]),
            (0.71, [0.7, 0.1, 0.2], [0.7, neg_inf, 0.2]),
            (0.71, [0.1, 0.7, 0.2], [neg_inf, 0.7, 0.2]),
            (0.71, [0.7, 0.2, 0.1], [0.7, 0.2, neg_inf]),
            (0.91, [0.7, 0.1, 0.2], [0.7, 0.1, 0.2]),
        )
        for p, logits, expected in scenarios:
            sampler = self._make_sampler(k=0, p=p)
            filtered_logits = sampler.apply_nucleus_filter(torch.tensor(logits))
            np.testing.assert_array_equal(
                torch.tensor(expected).numpy(), filtered_logits.numpy()
            )

    def test_top_k_filter(self):
        neg_inf = -float("Inf")
        # Each entry is (k, expected filtered logits) for the same input logits.
        scenarios = (
            (0, [0.7, 0.1, 0.2]),
            (1, [0.7, neg_inf, neg_inf]),
            (2, [0.7, neg_inf, 0.2]),
            (3, [0.7, 0.1, 0.2]),
        )
        for k, expected in scenarios:
            sampler = self._make_sampler(k=k, p=0)
            filtered_logits = sampler.apply_top_k_filter(torch.tensor([0.7, 0.1, 0.2]))
            np.testing.assert_array_equal(
                torch.tensor(expected).numpy(), filtered_logits.numpy()
            )

    @pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2")
    def test_wrong_k_value(self):
        # k is larger than the vocabulary size, which must produce a warning.
        k, vocab_size = 10, 5
        sampler = self._make_sampler(k=k, p=0)
        next_token_logits = torch.rand(vocab_size).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertWarns(UserWarning):
            _ = sampler.get_one_token(next_token_logits, past_sequence)

    def test_zero_temperature(self):
        # Sampling with a temperature of 0 must be rejected explicitly.
        sampler = self._make_sampler(k=0, p=0, temperature=0)
        next_token_logits = torch.rand(10).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertRaises(ZeroDivisionError):
            _ = sampler.get_one_token(next_token_logits, past_sequence)
|
||||
|
||||
|
||||
class SamplerSingleStackTest(unittest.TestCase):
    """ Decoding checks for single-stack models. """

    def test_raises_exception_when_no_LM_head(self):
        # Calling decode() on a model without LM head is expected to raise.
        for model in [BertModel(BertConfig())]:
            with self.assertRaises(AttributeError):
                model.decode()

    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        # (model, extra keyword arguments passed to decode)
        candidates = [
            (XLNetLMHeadModel(XLNetConfig()), {}),
            (XLMWithLMHeadModel(XLMConfig()), {"mask_token": 0}),
            (TransfoXLLMHeadModel(TransfoXLConfig()), {}),
            (GPT2LMHeadModel(GPT2Config()), {}),
            (OpenAIGPTLMHeadModel(OpenAIGPTConfig()), {}),
        ]
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8

        for model, kwargs_model in candidates:
            output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs_model)
            self.assertEqual(len(output), expected_length)
|
||||
|
||||
|
||||
class SamplerEncoderDecoderTest(unittest.TestCase):
    """ Decoding check for an encoder-decoder model. """

    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        model = Model2Model.from_pretrained("bert-base-uncased")

        generated_length = 5
        expected_length = 8
        encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)

        decode_options = {
            "decoder_prompt_ids": prompt,
            "k": 2,
            "p": 0.5,
            "repetition_penalty": 2,
            "length": generated_length,
        }
        output = model.decode(encoder_input_ids, **decode_options)
        self.assertEqual(len(output), expected_length)
|
||||
|
||||
|
||||
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user