diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index e334322290..9e2808f218 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -1135,7 +1135,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): mlen = mems[0].size(0) if mems is not None else 0 klen = mlen + qlen if self.same_length: - all_ones = word_emb.new_ones(qlen, klen) + all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8) mask_len = klen - self.mem_len if mask_len > 0: mask_shift_len = qlen - mask_len @@ -1145,7 +1145,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1 else: dec_attn_mask = torch.triu( - word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None] + word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None] hids = [] attentions = []