Torch 1.1.0 compatibility + FP16 O1 + TF checkpoints

Co-authored-by: wassname
This commit is contained in:
Lysandre
2019-11-11 15:12:54 -05:00
committed by Lysandre Debut
parent b18509c208
commit c9cb7f8a0f
2 changed files with 10 additions and 3 deletions

View File

@@ -203,8 +203,8 @@ class AlbertAttention(BertSelfAttention):
# Should find a better way to do this
w = self.dense.weight.T.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
b = self.dense.bias
w = self.dense.weight.t().view(self.num_attention_heads, self.attention_head_size, self.hidden_size).to(context_layer.dtype)
b = self.dense.bias.to(context_layer.dtype)
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
projected_context_layer_dropout = self.dropout(projected_context_layer)