run_classifier WIP
This commit is contained in:
@@ -237,7 +237,7 @@ class BERTSelfAttention(nn.Module):
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_x_shape)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
return context_layer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user