Replace strided slice with tf.expand_dims (#10078)

* Replace tf.newaxis -> tf.expand_dims

* Fix tests

* Fix tests

* Use reshape when a tensors needs a double expand

* Fix GPT2

* Fix GPT2
This commit is contained in:
Julien Plu
2021-02-09 17:48:28 +01:00
committed by GitHub
parent e7381c4596
commit b82fe7d258
17 changed files with 58 additions and 47 deletions

View File

@@ -134,7 +134,7 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -570,7 +570,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for