From 4d1819990294f27ab1cf0113034f52cdb4136eaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 12 Nov 2019 17:59:34 +0100 Subject: [PATCH] cast bool tensor to long for pytorch < 1.3 --- transformers/modeling_bert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 1ee3e3f097..0159d58aab 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -675,6 +675,7 @@ class BertModel(BertPreTrainedModel): batch_size, seq_length = input_shape seq_ids = torch.arange(seq_length, device=device) causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + causal_mask = causal_mask.to(torch.long) # not converting to long will cause errors with pytorch version < 1.3 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] else: extended_attention_mask = attention_mask[:, None, None, :]