bert weight loading from tf

This commit is contained in:
lukovnikov
2018-11-06 17:47:03 +01:00
parent 907d3569c1
commit 4e52188433
3 changed files with 102 additions and 27 deletions

View File

@@ -355,7 +355,7 @@ class BertModel(nn.Module):
all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
sequence_output = all_encoder_layers[-1]
pooled_output = self.pooler(sequence_output)
-return all_encoder_layers, pooled_output
+return [embedding_output] + all_encoder_layers, pooled_output
class BertForSequenceClassification(nn.Module):
"""BERT model for classification.