model training loop working – still have to check that everything is exactly same
This commit is contained in:
@@ -18,21 +18,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import six
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
def gelu(x):
|
||||
raise NotImplementedError
|
||||
# TF BERT says: cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
# OpenAI GPT gelu version was : 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
class BertConfig(object):
|
||||
@@ -152,12 +148,11 @@ class BERTEmbeddings(nn.Module):
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None):
|
||||
batch_size = input_ids.size(0)
|
||||
seq_length = input_ids.size(1)
|
||||
# TODO finich that
|
||||
position_ids = torch.range().view(batch_size, seq_length)
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(batch_size, seq_length)
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
@@ -218,14 +213,14 @@ class BERTSelfAttention(nn.Module):
|
||||
# TODO clean up this (precompute)
|
||||
# MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
|
||||
# `attention_mask` = [B, 1, F, T]
|
||||
attention_mask = tf.expand_dims(attention_mask, axis=[1])
|
||||
# attention_mask = tf.expand_dims(attention_mask, axis=[1])
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
adder = (1.0 - attention_mask) * -10000.0
|
||||
# adder = (1.0 - attention_mask) * -10000.0
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_scores += adder
|
||||
attention_scores += attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
# `attention_probs` = [B, N, F, T]
|
||||
@@ -289,7 +284,7 @@ class BERTOutput(nn.Module):
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(input_tensor)
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
@@ -390,6 +385,14 @@ class BertModel(nn.Module):
|
||||
self.pooler = BERTPooler(config)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask):
|
||||
# We create 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, from_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
|
||||
# It's more simple than the triangular masking of causal attention, just need to
|
||||
# prepare the broadcast here
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
all_encoder_layers = self.encoder(embedding_output, attention_mask)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
@@ -404,11 +407,11 @@ class BertForSequenceClassification(nn.Module):
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
|
||||
if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
|
||||
print("Initializing {}".format(m))
|
||||
# Slight difference here with the TF version which uses truncated_normal
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
m.weight.normal_(config.initializer_range)
|
||||
m.weight.data.normal_(config.initializer_range)
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
|
||||
|
||||
@@ -484,7 +484,7 @@ def main():
|
||||
num_train_steps = int(
|
||||
len(train_examples) / args.train_batch_size * args.num_train_epochs)
|
||||
|
||||
model = BertForSequenceClassification(bert_config)
|
||||
model = BertForSequenceClassification(bert_config, len(label_list))
|
||||
if args.init_checkpoint is not None:
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.to(device)
|
||||
@@ -504,10 +504,10 @@ def main():
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Num steps = %d", num_train_steps)
|
||||
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
if args.local_rank == -1:
|
||||
@@ -519,12 +519,12 @@ def main():
|
||||
model.train()
|
||||
global_step = 0
|
||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||
input_ids.to(device)
|
||||
input_mask.to(device)
|
||||
segment_ids.to(device)
|
||||
label_ids.to(device)
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
@@ -538,10 +538,10 @@ def main():
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
if args.local_rank == -1:
|
||||
@@ -554,10 +554,10 @@ def main():
|
||||
eval_loss = 0
|
||||
eval_accuracy = 0
|
||||
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
|
||||
input_ids.to(device)
|
||||
input_mask.to(device)
|
||||
segment_ids.to(device)
|
||||
label_ids.to(device)
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
tmp_eval_accuracy = accuracy(logits, label_ids)
|
||||
|
||||
Reference in New Issue
Block a user