model training loop working – still have to check that everything is exactly the same

This commit is contained in:
thomwolf
2018-11-02 01:31:31 +01:00
parent f690f0e167
commit 9343a2311b
2 changed files with 37 additions and 34 deletions

View File

@@ -18,21 +18,17 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import copy import copy
import json import json
import math import math
import re
import six import six
import tensorflow as tf
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
def gelu(x):
    """Apply the Gaussian Error Linear Unit activation element-wise.

    Computes ``x * Phi(x)``, where ``Phi`` is the standard normal cumulative
    distribution function evaluated exactly through ``torch.erf``.

    Args:
        x: a floating-point ``torch.Tensor``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Standard normal CDF, as in the TF BERT reference:
    #   cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    # The activation is the input weighted by the CDF; returning the CDF alone
    # would squash every activation into (0, 1).
    # The OpenAI GPT gelu version used the tanh approximation instead:
    #   0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    return x * cdf
class BertConfig(object): class BertConfig(object):
@@ -152,12 +148,11 @@ class BERTEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None): def forward(self, input_ids, token_type_ids=None):
batch_size = input_ids.size(0)
seq_length = input_ids.size(1) seq_length = input_ids.size(1)
# TODO finich that position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = torch.range().view(batch_size, seq_length) position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(batch_size, seq_length) token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids) words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
@@ -218,14 +213,14 @@ class BERTSelfAttention(nn.Module):
# TODO clean up this (precompute) # TODO clean up this (precompute)
# MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights # MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
# `attention_mask` = [B, 1, F, T] # `attention_mask` = [B, 1, F, T]
attention_mask = tf.expand_dims(attention_mask, axis=[1]) # attention_mask = tf.expand_dims(attention_mask, axis=[1])
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions. # positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - attention_mask) * -10000.0 # adder = (1.0 - attention_mask) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
attention_scores += adder attention_scores += attention_mask
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_probs` = [B, N, F, T] # `attention_probs` = [B, N, F, T]
@@ -289,7 +284,7 @@ class BERTOutput(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(input_tensor) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states return hidden_states
@@ -390,6 +385,14 @@ class BertModel(nn.Module):
self.pooler = BERTPooler(config) self.pooler = BERTPooler(config)
def forward(self, input_ids, token_type_ids, attention_mask): def forward(self, input_ids, token_type_ids, attention_mask):
# We create 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, from_seq_length]
# So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
# It's more simple than the triangular masking of causal attention, just need to
# prepare the broadcast here
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids) embedding_output = self.embeddings(input_ids, token_type_ids)
all_encoder_layers = self.encoder(embedding_output, attention_mask) all_encoder_layers = self.encoder(embedding_output, attention_mask)
sequence_output = all_encoder_layers[-1] sequence_output = all_encoder_layers[-1]
@@ -404,11 +407,11 @@ class BertForSequenceClassification(nn.Module):
self.classifier = nn.Linear(config.hidden_size, num_labels) self.classifier = nn.Linear(config.hidden_size, num_labels)
def init_weights(m):
    """Initialise the weights of ``nn.Linear`` and ``nn.Embedding`` modules.

    Written to be handed to ``nn.Module.apply``, which calls it once per
    sub-module; modules of any other type are left untouched.

    Args:
        m: the sub-module currently being visited.
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        print("Initializing {}".format(m))
        # Slight difference here with the TF version which uses truncated_normal
        # cf https://github.com/pytorch/pytorch/pull/5617
        # The first positional argument of normal_ is the mean, so the
        # initializer range has to be passed explicitly as the std.
        m.weight.data.normal_(mean=0.0, std=config.initializer_range)
self.apply(init_weights) self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, labels=None): def forward(self, input_ids, token_type_ids, attention_mask, labels=None):

View File

@@ -484,7 +484,7 @@ def main():
num_train_steps = int( num_train_steps = int(
len(train_examples) / args.train_batch_size * args.num_train_epochs) len(train_examples) / args.train_batch_size * args.num_train_epochs)
model = BertForSequenceClassification(bert_config) model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None: if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device) model.to(device)
@@ -504,10 +504,10 @@ def main():
logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps) logger.info(" Num steps = %d", num_train_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long) all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long) all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1: if args.local_rank == -1:
@@ -519,12 +519,12 @@ def main():
model.train() model.train()
global_step = 0 global_step = 0
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids.to(device) input_ids = input_ids.to(device)
input_mask.to(device) input_mask = input_mask.float().to(device)
segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids.to(device) label_ids = label_ids.to(device)
loss = model(input_ids, segment_ids, input_mask, label_ids) loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
global_step += 1 global_step += 1
@@ -538,10 +538,10 @@ def main():
logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size) logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long) all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long) all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long) all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1: if args.local_rank == -1:
@@ -554,10 +554,10 @@ def main():
eval_loss = 0 eval_loss = 0
eval_accuracy = 0 eval_accuracy = 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
input_ids.to(device) input_ids = input_ids.to(device)
input_mask.to(device) input_mask = input_mask.float().to(device)
segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids.to(device) label_ids = label_ids.to(device)
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
tmp_eval_accuracy = accuracy(logits, label_ids) tmp_eval_accuracy = accuracy(logits, label_ids)