diff --git a/transformers/modeling_distilbert.py b/transformers/modeling_distilbert.py index ebb89f0f95..d3b4ccff5d 100644 --- a/transformers/modeling_distilbert.py +++ b/transformers/modeling_distilbert.py @@ -159,8 +159,6 @@ class MultiHeadSelfAttention(nn.Module): dim_per_head = self.dim // self.n_heads - assert 2 <= mask.dim() <= 3 - causal = (mask.dim() == 3) mask_reshp = (bs, 1, 1, k_length) def shape(x): diff --git a/transformers/modeling_tf_distilbert.py b/transformers/modeling_tf_distilbert.py index fa2dc674af..f9fe4ca9e9 100644 --- a/transformers/modeling_tf_distilbert.py +++ b/transformers/modeling_tf_distilbert.py @@ -226,9 +226,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): dim_per_head = self.dim // self.n_heads - mask_shape = shape_list(mask) - assert 2 <= len(mask_shape) <= 3 - causal = (mask_shape) == 3) mask_reshape = [bs, 1, 1, k_length] def shape(x):