From 0d2bffad31e88fe72ec12eb20f5dc8996cbc6497 Mon Sep 17 00:00:00 2001 From: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com> Date: Wed, 7 Jul 2021 17:17:30 +0200 Subject: [PATCH] Remove tf.roll wherever not needed (#12512) It was used in shift_right. After this change, the TF code is more similar to the PyTorch implementations. Also, TF graphs are optimized (one node fewer). --- src/transformers/generation_tf_utils.py | 3 +-- src/transformers/models/bart/modeling_tf_bart.py | 5 ++--- src/transformers/models/blenderbot/modeling_tf_blenderbot.py | 5 ++--- .../models/blenderbot_small/modeling_tf_blenderbot_small.py | 5 ++--- src/transformers/models/led/modeling_tf_led.py | 5 ++--- src/transformers/models/marian/modeling_tf_marian.py | 5 ++--- src/transformers/models/pegasus/modeling_tf_pegasus.py | 5 ++--- src/transformers/models/rag/modeling_tf_rag.py | 3 +-- src/transformers/models/t5/modeling_tf_t5.py | 5 ++--- .../modeling_tf_{{cookiecutter.lowercase_modelname}}.py | 5 ++--- 10 files changed, 18 insertions(+), 28 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index b743755dd5..dabbad3a33 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -1571,9 +1571,8 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In ) # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1) sorted_indices_to_remove = tf.concat( - [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], + [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, :-1]], -1, ) # scatter sorted tensors to original indexing diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index f353aa8a5b..ef57688005 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py 
+++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -61,9 +61,8 @@ LARGE_NEGATIVE = -1e8 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 3e25194806..ab49a3c6a7 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -64,9 +64,8 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 
ef0bb6e4f3..1b65d60055 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -62,9 +62,8 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index abd9ebba37..117738485f 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -56,9 +56,8 @@ LARGE_NEGATIVE = -1e8 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids diff --git a/src/transformers/models/marian/modeling_tf_marian.py 
b/src/transformers/models/marian/modeling_tf_marian.py index 6894c30389..f0bfea8c54 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -63,9 +63,8 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index defc5ea24e..f71878ecae 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -63,9 +63,8 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), 
pad_token_id), shifted_input_ids diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py index 90f04ba44e..afd5472f18 100644 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -1365,9 +1365,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 8891f89404..9d5aee4625 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -873,9 +873,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel): decoder_start_token_id is not None ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information" - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
# replace possible -100 values in labels by `pad_token_id` diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index d4214787dc..5e1d866c78 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -1514,9 +1514,8 @@ LARGE_NEGATIVE = -1e8 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.roll(input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = tf.where( shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids