From bd2621583be21889aa21a763ce3ae3d8c69e4358 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 1 Oct 2020 18:15:41 +0200 Subject: [PATCH] fix data type (#7513) --- src/transformers/modeling_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 67b9a1a210..c152d76822 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -238,13 +238,20 @@ class ModuleUtilsMixin: seq_ids = torch.arange(seq_length, device=device) causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + if causal_mask.shape[1] < attention_mask.shape[1]: prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] causal_mask = torch.cat( - [torch.ones((batch_size, seq_length, prefix_seq_len), device=device), causal_mask], axis=-1 + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype + ), + causal_mask, + ], + axis=-1, ) - # causal and attention masks must have same type with pytorch version < 1.3 - causal_mask = causal_mask.to(attention_mask.dtype) extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] else: