Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -384,6 +384,432 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
|
||||
return model
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
||||
return {"inputs": inputs}
|
||||
|
||||
def _do_output_past(self, outputs):
|
||||
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
|
||||
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
|
||||
|
||||
if has_output_past and not has_mem_len and len(outputs) > 1:
|
||||
return True
|
||||
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids=None,
|
||||
max_length=None,
|
||||
do_sample=True,
|
||||
num_beams=None,
|
||||
temperature=None,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
repetition_penalty=None,
|
||||
bos_token_id=None,
|
||||
pad_token_id=None,
|
||||
eos_token_ids=None,
|
||||
length_penalty=None,
|
||||
num_return_sequences=None,
|
||||
):
|
||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
and beam-search.
|
||||
|
||||
Adapted in part from `Facebook's XLM beam search code`_.
|
||||
|
||||
.. _`Facebook's XLM beam search code`:
|
||||
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
|
||||
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
|
||||
The sequence used as a prompt for the generation. If `None` the method initializes
|
||||
it as an empty `torch.LongTensor` of shape `(1,)`.
|
||||
|
||||
max_length: (`optional`) int
|
||||
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
|
||||
|
||||
do_sample: (`optional`) bool
|
||||
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.
|
||||
|
||||
num_beams: (`optional`) int
|
||||
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
||||
|
||||
temperature: (`optional`) float
|
||||
The value used to module the next token probabilities. Must be strictely positive. Default to 1.0.
|
||||
|
||||
top_k: (`optional`) int
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
||||
|
||||
top_p: (`optional`) float
|
||||
The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
||||
|
||||
repetition_penalty: (`optional`) float
|
||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||
|
||||
bos_token_id: (`optional`) int
|
||||
Beginning of sentence token if no prompt is provided. Default to 0.
|
||||
|
||||
eos_token_ids: (`optional`) int or list of int
|
||||
End of sequence token or list of tokens to stop the generation. Default to 0.
|
||||
length_penalty: (`optional`) float
|
||||
Exponential penalty to the length. Default to 1.
|
||||
|
||||
num_return_sequences: (`optional`) int
|
||||
The number of independently computed returned sequences for each element in the batch. Default to 1.
|
||||
|
||||
Return:
|
||||
|
||||
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
|
||||
sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||
outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False) # do greedy decoding
|
||||
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
|
||||
input_context = 'The dog'
|
||||
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
|
||||
for i in range(3): # 3 output sequences were generated
|
||||
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||
input_context = 'The dog'
|
||||
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3) # 3 generate sequences using by sampling
|
||||
for i in range(3): # 3 output sequences were generated
|
||||
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
|
||||
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
|
||||
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
|
||||
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
|
||||
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||
|
||||
"""
|
||||
|
||||
# We cannot generate if the model does not have a LM head
|
||||
if self.get_output_embeddings() is None:
|
||||
raise AttributeError(
|
||||
"You tried to generate sequences with a model that does not have a LM Head."
|
||||
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
|
||||
)
|
||||
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
top_p = top_p if top_p is not None else self.config.top_p
|
||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
num_return_sequences = (
|
||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
|
||||
else:
|
||||
batch_size = 1
|
||||
if isinstance(eos_token_ids, int):
|
||||
eos_token_ids = [eos_token_ids]
|
||||
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictely positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||
assert input_ids is not None or (
|
||||
isinstance(bos_token_id, int) and bos_token_id >= 0
|
||||
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
|
||||
assert pad_token_id is None or (
|
||||
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
||||
), "`pad_token_id` should be a positive integer."
|
||||
assert (eos_token_ids is None) or (
|
||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||
assert length_penalty > 0, "`length_penalty` should be strictely positive."
|
||||
assert (
|
||||
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
||||
), "`num_return_sequences` should be a strictely positive integer."
|
||||
|
||||
if input_ids is None:
|
||||
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
|
||||
"you should either supply a context to complete as `input_ids` input "
|
||||
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
|
||||
)
|
||||
input_ids = tf.fill((batch_size, 1), bos_token_id)
|
||||
else:
|
||||
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
||||
|
||||
if pad_token_id is None and eos_token_ids is not None:
|
||||
logger.warning(
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
||||
)
|
||||
pad_token_id = eos_token_ids[0]
|
||||
|
||||
# current position and vocab size
|
||||
cur_len = shape_list(input_ids)[1]
|
||||
vocab_size = self.config.vocab_size
|
||||
|
||||
if num_return_sequences != 1:
|
||||
# Expand input to num return sequences
|
||||
input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
|
||||
effective_batch_size = batch_size * num_return_sequences
|
||||
input_ids = tf.reshape(input_ids, (effective_batch_size, cur_len))
|
||||
else:
|
||||
effective_batch_size = batch_size
|
||||
|
||||
if num_beams > 1:
|
||||
output = self._generate_beam_search(
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
effective_batch_size,
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
effective_batch_size,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _generate_no_beam_search(
|
||||
self,
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
batch_size,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
"""
|
||||
|
||||
def _create_next_token_logits_penalties(input_ids, logits):
|
||||
# create logit penalties for already seen input_ids
|
||||
token_penalties = np.ones(shape_list(logits))
|
||||
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
|
||||
for i, prev_input_id in enumerate(prev_input_ids):
|
||||
logit_penalized = logits[i].numpy()[prev_input_id]
|
||||
# if previous logit score is < 0 then multiply repetition penalty else divide
|
||||
logit_penalized[logit_penalized < 0] = repetition_penalty
|
||||
logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
|
||||
np.put(token_penalties[i], prev_input_id, logit_penalized)
|
||||
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
|
||||
|
||||
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||
unfinished_sents = tf.ones_like(input_ids[:, 0])
|
||||
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
|
||||
|
||||
past = None
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
past = outputs[1]
|
||||
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits)
|
||||
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
# Top-p/top-k filtering
|
||||
next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||
# Sample
|
||||
next_token = tf.squeeze(
|
||||
tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
|
||||
)
|
||||
else:
|
||||
# Greedy decoding
|
||||
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
|
||||
|
||||
# update generations and finished sentences
|
||||
if eos_token_ids is not None:
|
||||
# pad finished sentences if eos_token_ids exist
|
||||
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
||||
else:
|
||||
tokens_to_add = next_token
|
||||
|
||||
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
|
||||
|
||||
if eos_token_ids is not None:
|
||||
for eos_token_id in eos_token_ids:
|
||||
eos_in_sents = tokens_to_add == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
|
||||
unfinished_sents, tf.cast(eos_in_sents, tf.int32)
|
||||
)
|
||||
sent_lengths = (
|
||||
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
|
||||
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
|
||||
)
|
||||
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||
if tf.math.reduce_max(unfinished_sents) == 0:
|
||||
break
|
||||
|
||||
# if there are different sentences lengths in the batch, some batches have to be padded
|
||||
min_sent_length = tf.math.reduce_min(sent_lengths)
|
||||
max_sent_length = tf.math.reduce_max(sent_lengths)
|
||||
if min_sent_length != max_sent_length:
|
||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
|
||||
# finished sents are filled with pad_token
|
||||
padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id
|
||||
|
||||
# create length masks for tf.where operation
|
||||
broad_casted_sent_lengths = tf.broadcast_to(
|
||||
tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
|
||||
)
|
||||
broad_casted_range = tf.transpose(
|
||||
tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
|
||||
)
|
||||
|
||||
decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
|
||||
else:
|
||||
decoded = input_ids
|
||||
|
||||
return decoded
|
||||
|
||||
def _generate_beam_search(
|
||||
self,
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
batch_size,
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
logits_shape = shape_list(logits)
|
||||
|
||||
if top_k > 0:
|
||||
top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
|
||||
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_indices = tf.argsort(logits, direction="DESCENDING")
|
||||
sorted_logits = tf.gather(
|
||||
logits, sorted_indices, axis=-1, batch_dims=1
|
||||
) # expects logits to be of dim (batch_size, vocab_size)
|
||||
|
||||
cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
|
||||
if min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove = tf.concat(
|
||||
[
|
||||
tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
|
||||
sorted_indices_to_remove[:, min_tokens_to_keep:],
|
||||
],
|
||||
-1,
|
||||
)
|
||||
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
|
||||
sorted_indices_to_remove = tf.concat(
|
||||
[tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
|
||||
)
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
|
||||
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
|
||||
return logits
|
||||
|
||||
|
||||
def scatter_values_on_batch_indices(values, batch_indices):
|
||||
shape = shape_list(batch_indices)
|
||||
# broadcast batch dim to shape
|
||||
broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
|
||||
# transform batch_indices to pair_indices
|
||||
pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
|
||||
# scatter values to pair indices
|
||||
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
|
||||
|
||||
|
||||
def set_tensor_by_indices_to_value(tensor, indices, value):
|
||||
# create value_tensor since tensor value assignment is not possible in TF
|
||||
value_tensor = tf.zeros_like(tensor) + value
|
||||
return tf.where(indices, value_tensor, tensor)
|
||||
|
||||
|
||||
class TFConv1D(tf.keras.layers.Layer):
|
||||
def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user