simpler distilbert mask - fix tf tests

This commit is contained in:
thomwolf
2019-10-09 04:26:20 +02:00
parent 58b302caf3
commit 1c5079952f
2 changed files with 0 additions and 5 deletions

View File

@@ -159,8 +159,6 @@ class MultiHeadSelfAttention(nn.Module):
dim_per_head = self.dim // self.n_heads
assert 2 <= mask.dim() <= 3
causal = (mask.dim() == 3)
mask_reshp = (bs, 1, 1, k_length)
def shape(x):