Black 20 release
This commit is contained in:
@@ -329,7 +329,13 @@ class TFGenerationMixin:
|
||||
if self.config.is_encoder_decoder:
|
||||
|
||||
# create empty decoder_input_ids
|
||||
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
|
||||
input_ids = (
|
||||
tf.ones(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
dtype=tf.int32,
|
||||
)
|
||||
* decoder_start_token_id
|
||||
)
|
||||
cur_len = 1
|
||||
|
||||
assert (
|
||||
@@ -422,8 +428,8 @@ class TFGenerationMixin:
|
||||
attention_mask,
|
||||
use_cache,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
"""Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
"""
|
||||
|
||||
# length of generated sentences / unfinished sentences
|
||||
@@ -587,8 +593,7 @@ class TFGenerationMixin:
|
||||
attention_mask,
|
||||
use_cache,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
"""Generate sequences for each example with beam search."""
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
@@ -960,14 +965,14 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
|
||||
|
||||
|
||||
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
logits_shape = shape_list(logits)
|
||||
|
||||
@@ -1001,7 +1006,8 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
|
||||
sorted_indices_to_remove = tf.concat(
|
||||
[tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
|
||||
[tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]],
|
||||
-1,
|
||||
)
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
|
||||
@@ -1027,9 +1033,9 @@ def set_tensor_by_indices_to_value(tensor, indices, value):
|
||||
|
||||
def sample_without_replacement(logits, num_samples):
|
||||
"""
|
||||
categorical sampling witouth replacement is currently not implemented
|
||||
the gumbel-max trick will do for now
|
||||
see https://github.com/tensorflow/tensorflow/issues/9260 for more info
|
||||
categorical sampling witouth replacement is currently not implemented
|
||||
the gumbel-max trick will do for now
|
||||
see https://github.com/tensorflow/tensorflow/issues/9260 for more info
|
||||
"""
|
||||
z = -tf.math.log(tf.random.uniform(shape_list(logits), 0, 1))
|
||||
_, indices = tf.nn.top_k(logits + z, num_samples)
|
||||
|
||||
Reference in New Issue
Block a user