be sure we have uint8
This commit is contained in:
@@ -1135,7 +1135,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
mlen = mems[0].size(0) if mems is not None else 0
|
mlen = mems[0].size(0) if mems is not None else 0
|
||||||
klen = mlen + qlen
|
klen = mlen + qlen
|
||||||
if self.same_length:
|
if self.same_length:
|
||||||
all_ones = word_emb.new_ones(qlen, klen)
|
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
|
||||||
mask_len = klen - self.mem_len
|
mask_len = klen - self.mem_len
|
||||||
if mask_len > 0:
|
if mask_len > 0:
|
||||||
mask_shift_len = qlen - mask_len
|
mask_shift_len = qlen - mask_len
|
||||||
@@ -1145,7 +1145,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
+ torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
|
+ torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
|
||||||
else:
|
else:
|
||||||
dec_attn_mask = torch.triu(
|
dec_attn_mask = torch.triu(
|
||||||
word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
|
word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
|
||||||
|
|
||||||
hids = []
|
hids = []
|
||||||
attentions = []
|
attentions = []
|
||||||
|
|||||||
Reference in New Issue
Block a user