fix (#4410)
This commit is contained in:
@@ -586,6 +586,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
if data_mask is not None:
|
||||
# all mems can be attended to
|
||||
if mlen > 0:
|
||||
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
|
||||
data_mask = tf.concat([mems_mask, data_mask], axis=1)
|
||||
if attn_mask is None:
|
||||
@@ -598,6 +599,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
if attn_mask is not None:
|
||||
non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
|
||||
if mlen > 0:
|
||||
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
|
||||
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
|
||||
else:
|
||||
@@ -621,8 +623,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
# Segment embedding
|
||||
if token_type_ids is not None:
|
||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||
if mlen > 0:
|
||||
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
|
||||
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
|
||||
else:
|
||||
cat_ids = token_type_ids
|
||||
|
||||
# `1` indicates not in the same segment [qlen x klen x bsz]
|
||||
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
|
||||
@@ -640,14 +645,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
|
||||
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
|
||||
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
|
||||
head_mask = head_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # switch to float if needed + fp16 compatibility
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.n_layer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user