From 1d690289898bb4e7f69034333483c687d601015d Mon Sep 17 00:00:00 2001 From: ZhuBaohe Date: Tue, 26 May 2020 20:51:28 +0800 Subject: [PATCH] fix (#4410) --- src/transformers/modeling_tf_xlnet.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 1fb62833ca..25977e84af 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -586,8 +586,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): if data_mask is not None: # all mems can be attended to - mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float) - data_mask = tf.concat([mems_mask, data_mask], axis=1) + if mlen > 0: + mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float) + data_mask = tf.concat([mems_mask, data_mask], axis=1) if attn_mask is None: attn_mask = data_mask[:, :, :, None] else: @@ -598,7 +599,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): if attn_mask is not None: non_tgt_mask = -tf.eye(qlen, dtype=dtype_float) - non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1) + if mlen > 0: + non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1) non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float) else: non_tgt_mask = None @@ -621,8 +623,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): # Segment embedding if token_type_ids is not None: # Convert `token_type_ids` to one-hot `seg_mat` - mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) - cat_ids = tf.concat([mem_pad, token_type_ids], 0) + if mlen > 0: + mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) + cat_ids = tf.concat([mem_pad, token_type_ids], 0) + else: + cat_ids = token_type_ids # `1` indicates not in the same segment [qlen x klen x bsz] seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32) @@ -640,14 +645,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] if head_mask is not None: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) - head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1) - head_mask = head_mask.to( - dtype=next(self.parameters()).dtype - ) # switch to fload if need + fp16 compatibility + raise NotImplementedError else: head_mask = [None] * self.n_layer