fix (#4419)
This commit is contained in:
committed by
GitHub
parent
31c799a0c9
commit
a27c795908
@@ -929,7 +929,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
else:
|
else:
|
||||||
tokens_to_add = next_token
|
tokens_to_add = next_token
|
||||||
|
|
||||||
|
# add token and increase length by one
|
||||||
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
|
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
|
||||||
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
eos_in_sents = tokens_to_add == eos_token_id
|
eos_in_sents = tokens_to_add == eos_token_id
|
||||||
@@ -955,8 +957,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
|
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
cur_len = cur_len + 1
|
|
||||||
|
|
||||||
# if there are different sentences lengths in the batch, some batches have to be padded
|
# if there are different sentences lengths in the batch, some batches have to be padded
|
||||||
min_sent_length = tf.math.reduce_min(sent_lengths)
|
min_sent_length = tf.math.reduce_min(sent_lengths)
|
||||||
max_sent_length = tf.math.reduce_max(sent_lengths)
|
max_sent_length = tf.math.reduce_max(sent_lengths)
|
||||||
@@ -970,7 +970,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
|
tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
|
||||||
)
|
)
|
||||||
broad_casted_range = tf.transpose(
|
broad_casted_range = tf.transpose(
|
||||||
tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
|
tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size])
|
||||||
)
|
)
|
||||||
|
|
||||||
decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
|
decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
|
||||||
@@ -1205,9 +1205,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
|
beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
|
||||||
beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
|
beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
|
||||||
|
|
||||||
# re-order batch
|
# re-order batch and update current length
|
||||||
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
|
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
|
||||||
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
|
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
|
||||||
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
# re-order internal states
|
# re-order internal states
|
||||||
if past is not None:
|
if past is not None:
|
||||||
past = self._reorder_cache(past, beam_idx)
|
past = self._reorder_cache(past, beam_idx)
|
||||||
@@ -1218,9 +1220,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
|
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
# update current length
|
|
||||||
cur_len = cur_len + 1
|
|
||||||
|
|
||||||
# finalize all open beam hypotheses and end to generated hypotheses
|
# finalize all open beam hypotheses and end to generated hypotheses
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
# Add all open beam hypothesis to generated_hyps
|
# Add all open beam hypothesis to generated_hyps
|
||||||
|
|||||||
@@ -1236,13 +1236,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
else:
|
else:
|
||||||
tokens_to_add = next_token
|
tokens_to_add = next_token
|
||||||
|
|
||||||
|
# add token and increase length by one
|
||||||
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
||||||
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
if eos_token_id is not None:
|
if eos_token_id is not None:
|
||||||
eos_in_sents = tokens_to_add == eos_token_id
|
eos_in_sents = tokens_to_add == eos_token_id
|
||||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||||
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
||||||
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
|
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
|
||||||
# unfinished_sents is set to zero if eos in sentence
|
# unfinished_sents is set to zero if eos in sentence
|
||||||
unfinished_sents.mul_((~eos_in_sents).long())
|
unfinished_sents.mul_((~eos_in_sents).long())
|
||||||
|
|
||||||
@@ -1256,8 +1258,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
cur_len = cur_len + 1
|
|
||||||
|
|
||||||
# if there are different sentences lengths in the batch, some batches have to be padded
|
# if there are different sentences lengths in the batch, some batches have to be padded
|
||||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
|
assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
|
||||||
@@ -1473,9 +1473,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
|
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
|
||||||
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
||||||
|
|
||||||
# re-order batch
|
# re-order batch and update current length
|
||||||
input_ids = input_ids[beam_idx, :]
|
input_ids = input_ids[beam_idx, :]
|
||||||
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
||||||
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
# re-order internal states
|
# re-order internal states
|
||||||
if past is not None:
|
if past is not None:
|
||||||
past = self._reorder_cache(past, beam_idx)
|
past = self._reorder_cache(past, beam_idx)
|
||||||
@@ -1486,9 +1488,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
# update current length
|
|
||||||
cur_len = cur_len + 1
|
|
||||||
|
|
||||||
# finalize all open beam hypotheses and end to generated hypotheses
|
# finalize all open beam hypotheses and end to generated hypotheses
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
if done[batch_idx]:
|
if done[batch_idx]:
|
||||||
|
|||||||
Reference in New Issue
Block a user