fix (#4410)
This commit is contained in:
@@ -586,8 +586,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if data_mask is not None:
|
if data_mask is not None:
|
||||||
# all mems can be attended to
|
# all mems can be attended to
|
||||||
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
|
if mlen > 0:
|
||||||
data_mask = tf.concat([mems_mask, data_mask], axis=1)
|
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
|
||||||
|
data_mask = tf.concat([mems_mask, data_mask], axis=1)
|
||||||
if attn_mask is None:
|
if attn_mask is None:
|
||||||
attn_mask = data_mask[:, :, :, None]
|
attn_mask = data_mask[:, :, :, None]
|
||||||
else:
|
else:
|
||||||
@@ -598,7 +599,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
|
non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
|
||||||
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
|
if mlen > 0:
|
||||||
|
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
|
||||||
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
|
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
|
||||||
else:
|
else:
|
||||||
non_tgt_mask = None
|
non_tgt_mask = None
|
||||||
@@ -621,8 +623,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
# Segment embedding
|
# Segment embedding
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||||
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
|
if mlen > 0:
|
||||||
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
|
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
|
||||||
|
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
|
||||||
|
else:
|
||||||
|
cat_ids = token_type_ids
|
||||||
|
|
||||||
# `1` indicates not in the same segment [qlen x klen x bsz]
|
# `1` indicates not in the same segment [qlen x klen x bsz]
|
||||||
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
|
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
|
||||||
@@ -640,14 +645,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
|
||||||
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
|
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
|
||||||
if head_mask is not None:
|
if head_mask is not None:
|
||||||
if head_mask.dim() == 1:
|
raise NotImplementedError
|
||||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
|
|
||||||
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
|
|
||||||
elif head_mask.dim() == 2:
|
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
|
|
||||||
head_mask = head_mask.to(
|
|
||||||
dtype=next(self.parameters()).dtype
|
|
||||||
) # switch to float if needed + fp16 compatibility
|
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.n_layer
|
head_mask = [None] * self.n_layer
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user