be sure we have uint8
This commit is contained in:
@@ -1135,7 +1135,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
mlen = mems[0].size(0) if mems is not None else 0
|
||||
klen = mlen + qlen
|
||||
if self.same_length:
|
||||
all_ones = word_emb.new_ones(qlen, klen)
|
||||
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
|
||||
mask_len = klen - self.mem_len
|
||||
if mask_len > 0:
|
||||
mask_shift_len = qlen - mask_len
|
||||
@@ -1145,7 +1145,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
+ torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
|
||||
else:
|
||||
dec_attn_mask = torch.triu(
|
||||
word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
|
||||
word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
|
||||
|
||||
hids = []
|
||||
attentions = []
|
||||
|
||||
Reference in New Issue
Block a user