update all tf.shape and tensor.shape to shape_list
This commit is contained in:
@@ -95,7 +95,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
|
||||
batch_size = q.shape[0]
|
||||
batch_size = shape_list(q)[0]
|
||||
|
||||
q = self.Wq(q)
|
||||
k = self.Wk(k)
|
||||
|
||||
Reference in New Issue
Block a user