Fix longformer attention mask type casting when using apex (#4574)
* Fix longformer attention mask casting when using apex * remove extra type casting
This commit is contained in:
@@ -348,9 +348,7 @@ class LongformerSelfAttention(nn.Module):
|
||||
selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
|
||||
# use `matmul` because `einsum` crashes sometimes with fp16
|
||||
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
|
||||
attn = torch.matmul(
|
||||
selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)
|
||||
).transpose(1, 2)
|
||||
attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2)
|
||||
attn_probs = attn_probs.narrow(
|
||||
-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
|
||||
).contiguous()
|
||||
@@ -414,7 +412,7 @@ class LongformerSelfAttention(nn.Module):
|
||||
]
|
||||
attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
|
||||
len(selection_padding_mask_nonzeros[0]), -1
|
||||
).type_as(hidden_states)
|
||||
)
|
||||
|
||||
context_layer = attn.transpose(0, 1)
|
||||
if self.output_attentions:
|
||||
|
||||
Reference in New Issue
Block a user