run_classifier WIP + added classifier head and initialization to the model
This commit is contained in:
@@ -27,6 +27,7 @@ import six
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
def gelu(x):
|
||||
raise NotImplementedError
|
||||
@@ -394,3 +395,30 @@ class BertModel(nn.Module):
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
return all_encoder_layers, pooled_output
|
||||
|
||||
class BertForSequenceClassification(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The head is applied to the second value returned by ``BertModel``
    (``pooled_output``).
    """

    def __init__(self, config, num_labels):
        """Build the encoder and the classification head, then initialize weights.

        Args:
            config: Configuration object. This class reads
                ``hidden_dropout_prob``, ``hidden_size`` and
                ``initializer_range`` from it and forwards it to ``BertModel``.
            num_labels: Number of output classes of the linear classifier.
        """
        super(BertForSequenceClassification, self).__init__()
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

        # `apply` visits every sub-module (and this module itself), so the
        # encoder's Linear/Embedding layers are re-initialized as well.
        initializer_range = config.initializer_range
        self.apply(lambda module: self._init_weights(module, initializer_range))

    @staticmethod
    def _init_weights(module, initializer_range):
        """Re-draw the weights of Linear and Embedding modules from N(0, range^2).

        Modules of any other type are left untouched.

        Args:
            module: The module to (possibly) initialize.
            initializer_range: Standard deviation of the normal distribution.
        """
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        print("Initializing {}".format(module))
        # Slight difference here with the TF version which uses truncated_normal
        # cf https://github.com/pytorch/pytorch/pull/5617
        # In-place ops on a leaf Parameter that requires grad are rejected by
        # autograd, so the re-initialization must run without grad tracking.
        with torch.no_grad():
            module.weight.normal_(mean=0.0, std=initializer_range)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        """Compute classification logits and, if labels are given, the loss.

        Args:
            input_ids: Forwarded unchanged to ``self.bert``.
            token_type_ids: Forwarded unchanged to ``self.bert``.
            attention_mask: Forwarded unchanged to ``self.bert``.
            labels: Optional targets passed to ``CrossEntropyLoss`` together
                with the logits.

        Returns:
            ``logits`` when ``labels`` is None, otherwise the tuple
            ``(loss, logits)``.
        """
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is None:
            return logits

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        return loss, logits
|
||||
|
||||
Reference in New Issue
Block a user