This commit is contained in:
Patrick von Platen
2020-05-18 15:51:40 +02:00
committed by GitHub
parent 31c799a0c9
commit a27c795908
2 changed files with 12 additions and 14 deletions

View File

@@ -929,7 +929,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
else:
tokens_to_add = next_token
# add token and increase length by one
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
cur_len = cur_len + 1
if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id
@@ -955,8 +957,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
cur_len = cur_len + 1
# if there are different sentences lengths in the batch, some batches have to be padded
min_sent_length = tf.math.reduce_min(sent_lengths)
max_sent_length = tf.math.reduce_max(sent_lengths)
@@ -970,7 +970,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
)
broad_casted_range = tf.transpose(
tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size])
)
decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
@@ -1205,9 +1205,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
# re-order batch
# re-order batch and update current length
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
cur_len = cur_len + 1
# re-order internal states
if past is not None:
past = self._reorder_cache(past, beam_idx)
@@ -1218,9 +1220,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
)
# update current length
cur_len = cur_len + 1
# finalize all open beam hypotheses and end to generated hypotheses
for batch_idx in range(batch_size):
# Add all open beam hypothesis to generated_hyps