cast bool tensor to long for pytorch < 1.3
This commit is contained in:
committed by
Julien Chaumond
parent
9f75565ea8
commit
4d18199902
@@ -675,6 +675,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
seq_ids = torch.arange(seq_length, device=device)
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
|
causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3
|
||||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
else:
|
else:
|
||||||
extended_attention_mask = attention_mask[:, None, None, :]
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|||||||
Reference in New Issue
Block a user