Remove "double" assignment in TF-BART like models (#9997)
* Replace `attn_weights = attn_wegihts = tf.reshape(...)` with `attn_weights = tf.reshape(...)` and thus remove unintentionally used "double" assignment.
This commit is contained in:
@@ -241,7 +241,7 @@ class TFBartAttention(tf.keras.layers.Layer):
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
|
||||
@@ -244,7 +244,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
|
||||
@@ -273,7 +273,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
|
||||
@@ -247,7 +247,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
|
||||
@@ -274,7 +274,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
)
|
||||
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user