simpler distilbert mask - fix tf tests
This commit is contained in:
@@ -159,8 +159,6 @@ class MultiHeadSelfAttention(nn.Module):
|
|||||||
|
|
||||||
dim_per_head = self.dim // self.n_heads
|
dim_per_head = self.dim // self.n_heads
|
||||||
|
|
||||||
assert 2 <= mask.dim() <= 3
|
|
||||||
causal = (mask.dim() == 3)
|
|
||||||
mask_reshp = (bs, 1, 1, k_length)
|
mask_reshp = (bs, 1, 1, k_length)
|
||||||
|
|
||||||
def shape(x):
|
def shape(x):
|
||||||
|
|||||||
@@ -226,9 +226,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
dim_per_head = self.dim // self.n_heads
|
dim_per_head = self.dim // self.n_heads
|
||||||
|
|
||||||
mask_shape = shape_list(mask)
|
|
||||||
assert 2 <= len(mask_shape) <= 3
|
|
||||||
causal = (mask_shape) == 3)
|
|
||||||
mask_reshape = [bs, 1, 1, k_length]
|
mask_reshape = [bs, 1, 1, k_length]
|
||||||
|
|
||||||
def shape(x):
|
def shape(x):
|
||||||
|
|||||||
Reference in New Issue
Block a user