From a27c795908d697076b1737c6294011a3b88f04a4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 18 May 2020 15:51:40 +0200 Subject: [PATCH] fix (#4419) --- src/transformers/modeling_tf_utils.py | 13 ++++++------- src/transformers/modeling_utils.py | 13 ++++++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index ef98415fb4..1ee5b60682 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -929,7 +929,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): else: tokens_to_add = next_token + # add token and increase length by one input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1) + cur_len = cur_len + 1 if eos_token_id is not None: eos_in_sents = tokens_to_add == eos_token_id @@ -955,8 +957,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1 ) - cur_len = cur_len + 1 - # if there are different sentences lengths in the batch, some batches have to be padded min_sent_length = tf.math.reduce_min(sent_lengths) max_sent_length = tf.math.reduce_max(sent_lengths) @@ -970,7 +970,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length] ) broad_casted_range = tf.transpose( - tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size]) + tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size]) ) decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding) @@ -1205,9 +1205,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32) beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32) - # re-order batch + # re-order batch and update current length input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx]) input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1) + cur_len = cur_len + 1 + # re-order internal states if past is not None: past = self._reorder_cache(past, beam_idx) @@ -1218,9 +1220,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1 ) - # update current length - cur_len = cur_len + 1 - # finalize all open beam hypotheses and end to generated hypotheses for batch_idx in range(batch_size): # Add all open beam hypothesis to generated_hyps diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d5d06134bb..cd98907a94 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1236,13 +1236,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): else: tokens_to_add = next_token + # add token and increase length by one input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + cur_len = cur_len + 1 if eos_token_id is not None: eos_in_sents = tokens_to_add == eos_token_id # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() - sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1) + sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len) # unfinished_sents is set to zero if eos in sentence unfinished_sents.mul_((~eos_in_sents).long()) @@ -1256,8 +1258,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) - cur_len = cur_len + 1 - # if there are different sentences lengths in the batch, some batches have to be padded if sent_lengths.min().item() != sent_lengths.max().item(): assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths" @@ -1473,9 +1473,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) beam_idx = input_ids.new([x[2] for x in next_batch_beam]) - # re-order batch + # re-order batch and update current length input_ids = input_ids[beam_idx, :] input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) + cur_len = cur_len + 1 + # re-order internal states if past is not None: past = self._reorder_cache(past, beam_idx) @@ -1486,9 +1488,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) - # update current length - cur_len = cur_len + 1 - # finalize all open beam hypotheses and end to generated hypotheses for batch_idx in range(batch_size): if done[batch_idx]: