From 1c5079952f5f10eeac4cb6801b4fd1f36b0eff73 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 9 Oct 2019 04:26:20 +0200 Subject: [PATCH] simpler distilbert mask - fix tf tests --- transformers/modeling_distilbert.py | 2 -- transformers/modeling_tf_distilbert.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/transformers/modeling_distilbert.py b/transformers/modeling_distilbert.py index ebb89f0f95..d3b4ccff5d 100644 --- a/transformers/modeling_distilbert.py +++ b/transformers/modeling_distilbert.py @@ -159,8 +159,6 @@ class MultiHeadSelfAttention(nn.Module): dim_per_head = self.dim // self.n_heads - assert 2 <= mask.dim() <= 3 - causal = (mask.dim() == 3) mask_reshp = (bs, 1, 1, k_length) def shape(x): diff --git a/transformers/modeling_tf_distilbert.py b/transformers/modeling_tf_distilbert.py index fa2dc674af..f9fe4ca9e9 100644 --- a/transformers/modeling_tf_distilbert.py +++ b/transformers/modeling_tf_distilbert.py @@ -226,9 +226,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): dim_per_head = self.dim // self.n_heads - mask_shape = shape_list(mask) - assert 2 <= len(mask_shape) <= 3 - causal = (mask_shape) == 3) mask_reshape = [bs, 1, 1, k_length] def shape(x):