From 45de034bf899af678d844351ff21ea0444815ddb Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 17 Sep 2019 10:25:06 +0200 Subject: [PATCH] fix #1223 --- pytorch_transformers/modeling_xlnet.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index cc350bc956..069a43db69 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -743,8 +743,9 @@ class XLNetModel(XLNetPreTrainedModel): if data_mask is not None: # all mems can be attended to - mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask) - data_mask = torch.cat([mems_mask, data_mask], dim=1) + if mlen > 0: + mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask) + data_mask = torch.cat([mems_mask, data_mask], dim=1) if attn_mask is None: attn_mask = data_mask[:, :, :, None] else: @@ -755,7 +756,8 @@ class XLNetModel(XLNetPreTrainedModel): if attn_mask is not None: non_tgt_mask = -torch.eye(qlen).to(attn_mask) - non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1) + if mlen > 0: + non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1) non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask) else: non_tgt_mask = None @@ -775,8 +777,11 @@ class XLNetModel(XLNetPreTrainedModel): ##### Segment embedding if token_type_ids is not None: # Convert `token_type_ids` to one-hot `seg_mat` - mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device) - cat_ids = torch.cat([mem_pad, token_type_ids], dim=0) + if mlen > 0: + mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device) + cat_ids = torch.cat([mem_pad, token_type_ids], dim=0) + else: + cat_ids = token_type_ids # `1` indicates not in the same segment [qlen x klen x bsz] seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()