bert weight loading from tf
This commit is contained in:
@@ -355,7 +355,7 @@ class BertModel(nn.Module):
|
||||
all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
return all_encoder_layers, pooled_output
|
||||
return [embedding_output] + all_encoder_layers, pooled_output
|
||||
|
||||
class BertForSequenceClassification(nn.Module):
|
||||
"""BERT model for classification.
|
||||
|
||||
Reference in New Issue
Block a user