Remove tf.roll wherever not needed (#12512)

It was used in the shift-right helpers (e.g. `shift_tokens_right`).
After this change, the TF code is more similar to the PyTorch implementations.
Also, the TF graphs are optimized (one node fewer).
This commit is contained in:
Michal Szutenberg
2021-07-07 17:17:30 +02:00
committed by GitHub
parent 95425d546d
commit 0d2bffad31
10 changed files with 18 additions and 28 deletions

View File

@@ -1514,9 +1514,8 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids