From 23b7138ab495a5f39b648624a8dac73ce8d24f33 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 9 Oct 2019 01:54:44 +0200 Subject: [PATCH] fix #1378 and #1453 --- transformers/modeling_tf_distilbert.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformers/modeling_tf_distilbert.py b/transformers/modeling_tf_distilbert.py index 6ed2844567..fa2dc674af 100644 --- a/transformers/modeling_tf_distilbert.py +++ b/transformers/modeling_tf_distilbert.py @@ -226,8 +226,9 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer): dim_per_head = self.dim // self.n_heads - assert 2 <= len(tf.shape(mask)) <= 3 - causal = (len(tf.shape(mask)) == 3) + mask_shape = shape_list(mask) + assert 2 <= len(mask_shape) <= 3 + causal = (mask_shape) == 3) mask_reshape = [bs, 1, 1, k_length] def shape(x):