Fix #1223: guard the XLNetModel memory-padding concatenations with `if mlen > 0`
This commit is contained in:
@@ -743,8 +743,9 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
if data_mask is not None:
|
if data_mask is not None:
|
||||||
# all mems can be attended to
|
# all mems can be attended to
|
||||||
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
|
if mlen > 0:
|
||||||
data_mask = torch.cat([mems_mask, data_mask], dim=1)
|
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
|
||||||
|
data_mask = torch.cat([mems_mask, data_mask], dim=1)
|
||||||
if attn_mask is None:
|
if attn_mask is None:
|
||||||
attn_mask = data_mask[:, :, :, None]
|
attn_mask = data_mask[:, :, :, None]
|
||||||
else:
|
else:
|
||||||
@@ -755,7 +756,8 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
non_tgt_mask = -torch.eye(qlen).to(attn_mask)
|
non_tgt_mask = -torch.eye(qlen).to(attn_mask)
|
||||||
non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
|
if mlen > 0:
|
||||||
|
non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
|
||||||
non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
|
non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
|
||||||
else:
|
else:
|
||||||
non_tgt_mask = None
|
non_tgt_mask = None
|
||||||
@@ -775,8 +777,11 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
##### Segment embedding
|
##### Segment embedding
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||||
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
|
if mlen > 0:
|
||||||
cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
|
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
|
||||||
|
cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
|
||||||
|
else:
|
||||||
|
cat_ids = token_type_ids
|
||||||
|
|
||||||
# `1` indicates not in the same segment [qlen x klen x bsz]
|
# `1` indicates not in the same segment [qlen x klen x bsz]
|
||||||
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
|
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
|
||||||
|
|||||||
Reference in New Issue
Block a user