run_classifier WIP + added classifier head and initialization to the model

This commit is contained in:
thomwolf
2018-11-02 00:27:50 +01:00
parent 4a0b59e980
commit f690f0e167
2 changed files with 128 additions and 103 deletions

View File

@@ -27,6 +27,7 @@ import six
import tensorflow as tf
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
def gelu(x):
raise NotImplementedError
@@ -394,3 +395,30 @@ class BertModel(nn.Module):
sequence_output = all_encoder_layers[-1]
pooled_output = self.pooler(sequence_output)
return all_encoder_layers, pooled_output
class BertForSequenceClassification(nn.Module):
def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__()
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
def init_weights(m):
if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
print("Initializing {}".format(m))
# Slight difference here with the TF version which uses truncated_normal
# cf https://github.com/pytorch/pytorch/pull/5617
m.weight.normal_(config.initializer_range)
self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits, labels)
return loss, logits
else:
return logits