update all tf.shape and tensor.shape to shape_list

This commit is contained in:
thomwolf
2019-11-28 15:51:43 +01:00
committed by Lysandre Debut
parent 1ab8dc44b3
commit adb5c79ff2
13 changed files with 48 additions and 54 deletions

View File

@@ -92,7 +92,7 @@ class TFAttention(tf.keras.layers.Layer):
# q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True)
if self.scale:
dk = tf.cast(tf.shape(k)[-1], tf.float32) # scale attention_scores
dk = tf.cast(shape_list(k)[-1], tf.float32) # scale attention_scores
w = w / tf.math.sqrt(dk)
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.