diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 1ee3e3f097..0159d58aab 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -675,6 +675,7 @@ class BertModel(BertPreTrainedModel): batch_size, seq_length = input_shape seq_ids = torch.arange(seq_length, device=device) causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] else: extended_attention_mask = attention_mask[:, None, None, :]