Avoid using tf.tile in embeddings for TF models (#14735)
* avoid tf.tile in embeddings
* remove more tf.tile in embeddings
* clean

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -83,7 +83,6 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.initializer_range = config.initializer_range
|
||||
self.embeddings_sum = tf.keras.layers.Add()
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
@@ -142,9 +141,8 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
|
||||
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
|
||||
final_embeddings = inputs_embeds + position_embeds + token_type_embeds
|
||||
final_embeddings = self.LayerNorm(inputs=final_embeddings)
|
||||
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user