fixing model to add torchscript, embedding resizing, head pruning and masking + tests

This commit is contained in:
thomwolf
2019-08-28 13:22:45 +02:00
parent 62df4ba59a
commit c9bce1811c
3 changed files with 253 additions and 138 deletions

View File

@@ -449,7 +449,7 @@ class BertEncoder(nn.Module):
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):