Fix longformer attention mask type casting when using apex (#4574)
* Fix longformer attention mask casting when using apex * remove extra type casting
This commit is contained in:
@@ -348,9 +348,7 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
|
selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
|
||||||
# use `matmul` because `einsum` crashes sometimes with fp16
|
# use `matmul` because `einsum` crashes sometimes with fp16
|
||||||
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
|
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
|
||||||
attn = torch.matmul(
|
attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2)
|
||||||
selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)
|
|
||||||
).transpose(1, 2)
|
|
||||||
attn_probs = attn_probs.narrow(
|
attn_probs = attn_probs.narrow(
|
||||||
-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
|
-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
|
||||||
).contiguous()
|
).contiguous()
|
||||||
@@ -414,7 +412,7 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
]
|
]
|
||||||
attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
|
attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
|
||||||
len(selection_padding_mask_nonzeros[0]), -1
|
len(selection_padding_mask_nonzeros[0]), -1
|
||||||
).type_as(hidden_states)
|
)
|
||||||
|
|
||||||
context_layer = attn.transpose(0, 1)
|
context_layer = attn.transpose(0, 1)
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
|
|||||||
Reference in New Issue
Block a user