refactor variable naming and improve tf generate in line with torch generate
This commit is contained in:
@@ -459,6 +459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
max_length=None,
|
max_length=None,
|
||||||
|
min_length=None,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
early_stopping=False,
|
early_stopping=False,
|
||||||
num_beams=None,
|
num_beams=None,
|
||||||
@@ -470,7 +471,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
pad_token_id=None,
|
pad_token_id=None,
|
||||||
eos_token_ids=None,
|
eos_token_ids=None,
|
||||||
length_penalty=None,
|
length_penalty=None,
|
||||||
|
no_repeat_ngram_size=None,
|
||||||
num_return_sequences=None,
|
num_return_sequences=None,
|
||||||
|
attention_mask=None,
|
||||||
):
|
):
|
||||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||||
and beam-search.
|
and beam-search.
|
||||||
@@ -564,6 +567,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
|
min_length = min_length if min_length is not None else self.config.min_length
|
||||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||||
@@ -575,6 +579,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
||||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||||
|
no_repeat_ngram_size = (
|
||||||
|
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||||
|
)
|
||||||
num_return_sequences = (
|
num_return_sequences = (
|
||||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||||
)
|
)
|
||||||
@@ -587,6 +594,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
eos_token_ids = [eos_token_ids]
|
eos_token_ids = [eos_token_ids]
|
||||||
|
|
||||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||||
|
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||||
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
||||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||||
@@ -631,6 +639,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
num_beams >= num_return_sequences
|
num_beams >= num_return_sequences
|
||||||
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
||||||
|
|
||||||
|
# create attention mask if necessary
|
||||||
|
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
|
||||||
|
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
|
||||||
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
|
||||||
|
elif attention_mask is None:
|
||||||
|
attention_mask = tf.ones_like(input_ids)
|
||||||
|
|
||||||
if pad_token_id is None and eos_token_ids is not None:
|
if pad_token_id is None and eos_token_ids is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
||||||
@@ -655,42 +670,55 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
input_ids = tf.broadcast_to(
|
input_ids = tf.broadcast_to(
|
||||||
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
|
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||||
)
|
)
|
||||||
|
attention_mask = tf.broadcast_to(
|
||||||
|
tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||||
|
)
|
||||||
input_ids = tf.reshape(
|
input_ids = tf.reshape(
|
||||||
input_ids, (effective_batch_size * num_beams, input_ids_len)
|
input_ids, (effective_batch_size * num_beams, input_ids_len)
|
||||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||||
|
attention_mask = tf.reshape(
|
||||||
|
attention_mask, (effective_batch_size * num_beams, input_ids_len)
|
||||||
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||||
|
|
||||||
if num_beams > 1:
|
if num_beams > 1:
|
||||||
output = self._generate_beam_search(
|
output = self._generate_beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
cur_len,
|
cur_len=cur_len,
|
||||||
max_length,
|
max_length=max_length,
|
||||||
do_sample,
|
min_length=min_length,
|
||||||
early_stopping,
|
do_sample=do_sample,
|
||||||
temperature,
|
early_stopping=early_stopping,
|
||||||
top_k,
|
temperature=temperature,
|
||||||
top_p,
|
top_k=top_k,
|
||||||
repetition_penalty,
|
top_p=top_p,
|
||||||
pad_token_id,
|
repetition_penalty=repetition_penalty,
|
||||||
eos_token_ids,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
effective_batch_size,
|
pad_token_id=pad_token_id,
|
||||||
num_return_sequences,
|
eos_token_ids=eos_token_ids,
|
||||||
length_penalty,
|
batch_size=effective_batch_size,
|
||||||
num_beams,
|
num_return_sequences=num_return_sequences,
|
||||||
vocab_size,
|
length_penalty=length_penalty,
|
||||||
|
num_beams=num_beams,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output = self._generate_no_beam_search(
|
output = self._generate_no_beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
cur_len,
|
cur_len=cur_len,
|
||||||
max_length,
|
max_length=max_length,
|
||||||
do_sample,
|
min_length=min_length,
|
||||||
temperature,
|
do_sample=do_sample,
|
||||||
top_k,
|
temperature=temperature,
|
||||||
top_p,
|
top_k=top_k,
|
||||||
repetition_penalty,
|
top_p=top_p,
|
||||||
pad_token_id,
|
repetition_penalty=repetition_penalty,
|
||||||
eos_token_ids,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
effective_batch_size,
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_ids=eos_token_ids,
|
||||||
|
batch_size=effective_batch_size,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -700,14 +728,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
input_ids,
|
input_ids,
|
||||||
cur_len,
|
cur_len,
|
||||||
max_length,
|
max_length,
|
||||||
|
min_length,
|
||||||
do_sample,
|
do_sample,
|
||||||
temperature,
|
temperature,
|
||||||
top_k,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
|
no_repeat_ngram_size,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
vocab_size,
|
||||||
|
attention_mask,
|
||||||
):
|
):
|
||||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||||
All returned sequence are generated independantly.
|
All returned sequence are generated independantly.
|
||||||
@@ -720,7 +752,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
past = None
|
past = None
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs[0][:, -1, :]
|
||||||
|
|
||||||
@@ -735,6 +767,33 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
)
|
)
|
||||||
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
|
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
|
||||||
|
|
||||||
|
if no_repeat_ngram_size > 0:
|
||||||
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||||
|
# create banned_tokens boolean mask
|
||||||
|
banned_tokens_indices_mask = []
|
||||||
|
for banned_tokens_slice in banned_tokens:
|
||||||
|
banned_tokens_indices_mask.append(
|
||||||
|
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||||
|
)
|
||||||
|
|
||||||
|
next_token_logits = set_tensor_by_indices_to_value(
|
||||||
|
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
# set eos token prob to zero if min_length is not reached
|
||||||
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
|
# create eos_token_ids boolean mask
|
||||||
|
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||||
|
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
|
||||||
|
)
|
||||||
|
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||||
|
|
||||||
|
next_token_logits = set_tensor_by_indices_to_value(
|
||||||
|
next_token_logits, eos_token_indices_mask, -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
@@ -806,12 +865,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
input_ids,
|
input_ids,
|
||||||
cur_len,
|
cur_len,
|
||||||
max_length,
|
max_length,
|
||||||
|
min_length,
|
||||||
do_sample,
|
do_sample,
|
||||||
early_stopping,
|
early_stopping,
|
||||||
temperature,
|
temperature,
|
||||||
top_k,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
|
no_repeat_ngram_size,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids,
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -819,6 +880,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
length_penalty,
|
length_penalty,
|
||||||
num_beams,
|
num_beams,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
|
attention_mask,
|
||||||
):
|
):
|
||||||
""" Generate sequences for each example with beam search.
|
""" Generate sequences for each example with beam search.
|
||||||
"""
|
"""
|
||||||
@@ -829,7 +891,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
for _ in range(batch_size)
|
for _ in range(batch_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
# scores for each sentence in the beam
|
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
||||||
if do_sample is False:
|
if do_sample is False:
|
||||||
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
|
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
|
||||||
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
|
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
|
||||||
@@ -845,7 +907,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
@@ -860,12 +922,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
)
|
)
|
||||||
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
|
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
|
||||||
|
|
||||||
if do_sample:
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
if temperature != 1.0:
|
||||||
if temperature != 1.0:
|
next_token_logits = next_token_logits / temperature
|
||||||
next_token_logits = next_token_logits / temperature
|
|
||||||
|
|
||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
if no_repeat_ngram_size > 0:
|
||||||
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||||
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||||
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||||
|
# create banned_tokens boolean mask
|
||||||
|
banned_tokens_indices_mask = []
|
||||||
|
for banned_tokens_slice in banned_tokens:
|
||||||
|
banned_tokens_indices_mask.append(
|
||||||
|
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
||||||
|
)
|
||||||
|
|
||||||
|
next_token_logits = set_tensor_by_indices_to_value(
|
||||||
|
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
# set eos token prob to zero if min_length is not reached
|
||||||
|
if eos_token_ids is not None and cur_len < min_length:
|
||||||
|
# create eos_token_ids boolean mask
|
||||||
|
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||||
|
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
|
||||||
|
)
|
||||||
|
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||||
|
|
||||||
|
next_token_logits = set_tensor_by_indices_to_value(
|
||||||
|
next_token_logits, eos_token_indices_mask, -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate log softmax score
|
||||||
|
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
|
||||||
|
|
||||||
|
if do_sample:
|
||||||
_scores = scores + tf.broadcast_to(
|
_scores = scores + tf.broadcast_to(
|
||||||
beam_scores[:, None], (batch_size * num_beams, vocab_size)
|
beam_scores[:, None], (batch_size * num_beams, vocab_size)
|
||||||
) # (batch_size * num_beams, vocab_size)
|
) # (batch_size * num_beams, vocab_size)
|
||||||
@@ -888,9 +980,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
|
next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
|
||||||
next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
|
next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
|
||||||
else:
|
else:
|
||||||
# do greedy beam search
|
|
||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
|
||||||
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
|
|
||||||
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
||||||
next_scores = scores + tf.broadcast_to(
|
next_scores = scores + tf.broadcast_to(
|
||||||
beam_scores[:, None], (batch_size * num_beams, vocab_size)
|
beam_scores[:, None], (batch_size * num_beams, vocab_size)
|
||||||
@@ -912,10 +1001,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
# for each sentence
|
# for each sentence
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
|
|
||||||
# if we are done with this sentence
|
|
||||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
|
||||||
tf.reduce_max(next_scores[batch_idx]).numpy()
|
|
||||||
)
|
|
||||||
if done[batch_idx]:
|
if done[batch_idx]:
|
||||||
assert (
|
assert (
|
||||||
len(generated_hyps[batch_idx]) >= num_beams
|
len(generated_hyps[batch_idx]) >= num_beams
|
||||||
@@ -930,29 +1015,46 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
next_sent_beam = []
|
next_sent_beam = []
|
||||||
|
|
||||||
# next tokens for this sentence
|
# next tokens for this sentence
|
||||||
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
|
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
||||||
|
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
||||||
|
):
|
||||||
|
|
||||||
# get beam and token IDs
|
# get beam and token IDs
|
||||||
beam_id = idx // vocab_size
|
beam_id = beam_token_id // vocab_size
|
||||||
token_id = idx % vocab_size
|
token_id = beam_token_id % vocab_size
|
||||||
|
|
||||||
effective_beam_id = batch_idx * num_beams + beam_id
|
effective_beam_id = batch_idx * num_beams + beam_id
|
||||||
# add to generated hypotheses if end of sentence or last iteration
|
# add to generated hypotheses if end of sentence or last iteration
|
||||||
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
|
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
|
||||||
generated_hyps[batch_idx].add(tf.identity(input_ids[effective_beam_id]), score.numpy())
|
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||||
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||||
|
if is_beam_token_worse_than_top_num_beams:
|
||||||
|
continue
|
||||||
|
generated_hyps[batch_idx].add(
|
||||||
|
tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted token if it is not eos_token
|
# add next predicted token if it is not eos_token
|
||||||
next_sent_beam.append((score, token_id, effective_beam_id))
|
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
||||||
|
|
||||||
# the beam for next step is full
|
# the beam for next step is full
|
||||||
if len(next_sent_beam) == num_beams:
|
if len(next_sent_beam) == num_beams:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# if we are done with this sentence
|
||||||
|
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||||
|
tf.reduce_max(next_scores[batch_idx]).numpy()
|
||||||
|
)
|
||||||
|
|
||||||
# update next beam content
|
# update next beam content
|
||||||
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
||||||
next_batch_beam.extend(next_sent_beam)
|
next_batch_beam.extend(next_sent_beam)
|
||||||
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
|
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
|
||||||
|
|
||||||
|
# stop when we are done with each sentence
|
||||||
|
if all(done):
|
||||||
|
break
|
||||||
|
|
||||||
# sanity check / prepare next batch
|
# sanity check / prepare next batch
|
||||||
assert len(next_batch_beam) == batch_size * num_beams
|
assert len(next_batch_beam) == batch_size * num_beams
|
||||||
beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
|
beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
|
||||||
@@ -967,10 +1069,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
if past:
|
if past:
|
||||||
past = self._reorder_cache(past, beam_idx)
|
past = self._reorder_cache(past, beam_idx)
|
||||||
|
|
||||||
# stop when we are done with each sentence
|
|
||||||
if all(done):
|
|
||||||
break
|
|
||||||
|
|
||||||
# update current length
|
# update current length
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
@@ -1072,6 +1170,29 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
|||||||
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
|
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
|
||||||
|
# Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||||
|
if cur_len + 1 < no_repeat_ngram_size:
|
||||||
|
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||||
|
return [[] for _ in range(num_hypos)]
|
||||||
|
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||||
|
for idx in range(num_hypos):
|
||||||
|
gen_tokens = prev_input_ids[idx].numpy().tolist()
|
||||||
|
generated_ngram = generated_ngrams[idx]
|
||||||
|
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
||||||
|
prev_ngram_tuple = tuple(ngram[:-1])
|
||||||
|
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
||||||
|
|
||||||
|
def _get_generated_ngrams(hypo_idx):
|
||||||
|
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||||
|
start_idx = cur_len + 1 - no_repeat_ngram_size
|
||||||
|
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
|
||||||
|
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||||
|
|
||||||
|
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||||
|
return banned_tokens
|
||||||
|
|
||||||
|
|
||||||
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
min_length = min_length if min_length is not None else self.config.min_length
|
min_length = min_length if min_length is not None else self.config.min_length
|
||||||
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||||
temperature = temperature if temperature is not None else self.config.temperature
|
temperature = temperature if temperature is not None else self.config.temperature
|
||||||
@@ -852,7 +853,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
cur_len = 1
|
cur_len = 1
|
||||||
self.model.decoder.generation_mode = True
|
|
||||||
|
# put model in generation mode if it has one
|
||||||
|
if hasattr(self.model, "generation_mode"):
|
||||||
|
self.model.decoder.generation_mode = True
|
||||||
else:
|
else:
|
||||||
encoder_inputs = None
|
encoder_inputs = None
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
@@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
if num_beams > 1:
|
if num_beams > 1:
|
||||||
output = self._generate_beam_search(
|
output = self._generate_beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
cur_len,
|
cur_len=cur_len,
|
||||||
max_length,
|
max_length=max_length,
|
||||||
min_length,
|
min_length=min_length,
|
||||||
do_sample,
|
do_sample=do_sample,
|
||||||
early_stopping,
|
early_stopping=early_stopping,
|
||||||
temperature,
|
temperature=temperature,
|
||||||
top_k,
|
top_k=top_k,
|
||||||
top_p,
|
top_p=top_p,
|
||||||
repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids=eos_token_ids,
|
||||||
effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
num_return_sequences,
|
num_return_sequences=num_return_sequences,
|
||||||
length_penalty,
|
length_penalty=length_penalty,
|
||||||
num_beams,
|
num_beams=num_beams,
|
||||||
vocab_size,
|
vocab_size=vocab_size,
|
||||||
encoder_inputs,
|
encoder_inputs=encoder_inputs,
|
||||||
attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output = self._generate_no_beam_search(
|
output = self._generate_no_beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
cur_len,
|
cur_len=cur_len,
|
||||||
max_length,
|
max_length=max_length,
|
||||||
min_length,
|
min_length=min_length,
|
||||||
do_sample,
|
do_sample=do_sample,
|
||||||
temperature,
|
temperature=temperature,
|
||||||
top_k,
|
top_k=top_k,
|
||||||
top_p,
|
top_p=top_p,
|
||||||
repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_ids=eos_token_ids,
|
||||||
effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
encoder_inputs,
|
encoder_inputs=encoder_inputs,
|
||||||
attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
next_sent_beam = []
|
next_sent_beam = []
|
||||||
|
|
||||||
# next tokens for this sentence
|
# next tokens for this sentence
|
||||||
for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
|
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
||||||
|
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
||||||
|
):
|
||||||
# get beam and word IDs
|
# get beam and word IDs
|
||||||
beam_id = idx // vocab_size
|
beam_id = beam_token_id // vocab_size
|
||||||
token_id = idx % vocab_size
|
token_id = beam_token_id % vocab_size
|
||||||
|
|
||||||
effective_beam_id = batch_idx * num_beams + beam_id
|
effective_beam_id = batch_idx * num_beams + beam_id
|
||||||
|
|
||||||
# add to generated hypotheses if end of sentence
|
# add to generated hypotheses if end of sentence
|
||||||
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
|
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
|
||||||
# when passed to num_beams hypotheses, continue
|
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||||
if i >= num_beams:
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||||
|
if is_beam_token_worse_than_top_num_beams:
|
||||||
continue
|
continue
|
||||||
generated_hyps[batch_idx].add(
|
generated_hyps[batch_idx].add(
|
||||||
input_ids[effective_beam_id].clone(), score.item(),
|
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted word if it is not eos_token
|
# add next predicted word if it is not eos_token
|
||||||
next_sent_beam.append((score, token_id, effective_beam_id))
|
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
||||||
|
|
||||||
# the beam for next step is full
|
# the beam for next step is full
|
||||||
if len(next_sent_beam) == num_beams:
|
if len(next_sent_beam) == num_beams:
|
||||||
|
|||||||
Reference in New Issue
Block a user