update all tf.shape and tensor.shape to shape_list
This commit is contained in:
@@ -92,7 +92,7 @@ class TFAttention(tf.keras.layers.Layer):
|
||||
# q, k, v have shape [batch, heads, sequence, features]
|
||||
w = tf.matmul(q, k, transpose_b=True)
|
||||
if self.scale:
|
||||
dk = tf.cast(tf.shape(k)[-1], tf.float32) # scale attention_scores
|
||||
dk = tf.cast(shape_list(k)[-1], tf.float32) # scale attention_scores
|
||||
w = w / tf.math.sqrt(dk)
|
||||
|
||||
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
|
||||
|
||||
Reference in New Issue
Block a user