run_classifier WIP

This commit is contained in:
thomwolf
2018-11-01 21:05:04 +01:00
parent 7af7f8173b
commit 4a0b59e980
3 changed files with 46 additions and 68 deletions

View File

@@ -237,7 +237,7 @@ class BERTSelfAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_x_shape)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer