simpler distilbert mask - fix tf tests

This commit is contained in:
thomwolf
2019-10-09 04:26:20 +02:00
parent 58b302caf3
commit 1c5079952f
2 changed files with 0 additions and 5 deletions

View File

@@ -159,8 +159,6 @@ class MultiHeadSelfAttention(nn.Module):
dim_per_head = self.dim // self.n_heads
assert 2 <= mask.dim() <= 3
causal = (mask.dim() == 3)
mask_reshp = (bs, 1, 1, k_length)
def shape(x):

View File

@@ -226,9 +226,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
dim_per_head = self.dim // self.n_heads
mask_shape = shape_list(mask)
assert 2 <= len(mask_shape) <= 3
causal = (mask_shape) == 3)
mask_reshape = [bs, 1, 1, k_length]
def shape(x):