From 33b7532e69aaed83f31e8ad5da0a36c28e140da2 Mon Sep 17 00:00:00 2001 From: Wei Fang Date: Sat, 30 May 2020 00:13:30 +0800 Subject: [PATCH] Fix longformer attention mask type casting when using apex (#4574) * Fix longformer attention mask casting when using apex * remove extra type casting --- src/transformers/modeling_longformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 2a8e862c2e..d254c115fa 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -348,9 +348,7 @@ class LongformerSelfAttention(nn.Module): selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] # use `matmul` because `einsum` crashes sometimes with fp16 # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) - attn = torch.matmul( - selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs) - ).transpose(1, 2) + attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2) attn_probs = attn_probs.narrow( -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch ).contiguous() @@ -414,7 +412,7 @@ class LongformerSelfAttention(nn.Module): ] attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view( len(selection_padding_mask_nonzeros[0]), -1 - ).type_as(hidden_states) + ) context_layer = attn.transpose(0, 1) if self.output_attentions: