make GPT2 and CTRL shape consistent between torch and TF

This commit is contained in:
Patrick von Platen
2020-03-04 11:09:45 +01:00
parent 2529b2d37e
commit c4c4c9998a
3 changed files with 30 additions and 13 deletions

View File

@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=1)
past_key, past_value = tf.unstack(layer_past, axis=0)
k = tf.concat((past_key, k), axis=-2)
v = tf.concat((past_value, v), axis=-2)
present = tf.stack((k, v), axis=1)
present = tf.stack((k, v), axis=0)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])