update all tf.shape and tensor.shape to shape_list

This commit is contained in:
thomwolf
2019-11-28 15:51:43 +01:00
committed by Lysandre Debut
parent 1ab8dc44b3
commit adb5c79ff2
13 changed files with 48 additions and 54 deletions

View File

@@ -95,7 +95,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
batch_size = q.shape[0]
batch_size = shape_list(q)[0]
q = self.Wq(q)
k = self.Wk(k)