From 9343a2311b1fad2b35f1d355b6e6e115e0f04002 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 2 Nov 2018 01:31:31 +0100 Subject: [PATCH] =?UTF-8?q?model=20training=20loop=20working=20=E2=80=93?= =?UTF-8?q?=20still=20have=20to=20check=20that=20everything=20is=20exactly?= =?UTF-8?q?=20same?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modeling_pytorch.py | 35 +++++++++++++++++++---------------- run_classifier_pytorch.py | 36 ++++++++++++++++++------------------ 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/modeling_pytorch.py b/modeling_pytorch.py index 9ca262928f..2ed222071e 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -18,21 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import copy import json import math -import re import six -import tensorflow as tf import torch import torch.nn as nn from torch.nn import CrossEntropyLoss def gelu(x): - raise NotImplementedError - # TF BERT says: cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) - return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + # OpenAI GPT gelu version was : 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) class BertConfig(object): @@ -152,12 +148,11 @@ class BERTEmbeddings(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None): - batch_size = input_ids.size(0) seq_length = input_ids.size(1) - # TODO finich that - position_ids = torch.range().view(batch_size, seq_length) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) if token_type_ids is None: - token_type_ids = torch.zeros(batch_size, seq_length) + token_type_ids = torch.zeros_like(input_ids) words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) @@ -218,14 +213,14 @@ class BERTSelfAttention(nn.Module): # TODO clean up this (precompute) # MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights # `attention_mask` = [B, 1, F, T] - attention_mask = tf.expand_dims(attention_mask, axis=[1]) + # attention_mask = tf.expand_dims(attention_mask, axis=[1]) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. - adder = (1.0 - attention_mask) * -10000.0 + # adder = (1.0 - attention_mask) * -10000.0 # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - attention_scores += adder + attention_scores += attention_mask # Normalize the attention scores to probabilities. # `attention_probs` = [B, N, F, T] @@ -289,7 +284,7 @@ class BERTOutput(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(input_tensor) + hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states @@ -390,6 +385,14 @@ class BertModel(nn.Module): self.pooler = BERTPooler(config) def forward(self, input_ids, token_type_ids, attention_mask): + # We create 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, from_seq_length] + # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] + # It's more simple than the triangular masking of causal attention, just need to + # prepare the broadcast here + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = (1.0 - attention_mask) * -10000.0 + embedding_output = self.embeddings(input_ids, token_type_ids) all_encoder_layers = self.encoder(embedding_output, attention_mask) sequence_output = all_encoder_layers[-1] @@ -404,11 +407,11 @@ class BertForSequenceClassification(nn.Module): self.classifier = nn.Linear(config.hidden_size, num_labels) def init_weights(m): - if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding: + if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding): print("Initializing {}".format(m)) # Slight difference here with the TF version which uses truncated_normal # cf https://github.com/pytorch/pytorch/pull/5617 - m.weight.normal_(config.initializer_range) + m.weight.data.normal_(config.initializer_range) self.apply(init_weights) def forward(self, input_ids, token_type_ids, attention_mask, labels=None): diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index ff90c19314..3ad28726a3 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -484,7 +484,7 @@ def main(): num_train_steps = int( len(train_examples) / args.train_batch_size * args.num_train_epochs) - model = BertForSequenceClassification(bert_config) + model = BertForSequenceClassification(bert_config, len(label_list)) if args.init_checkpoint is not None: model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) @@ -504,10 +504,10 @@ def main(): logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_steps) - all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long) - all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long) - all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long) - all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long) + all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) + all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) if args.local_rank == -1: @@ -519,12 +519,12 @@ def main(): model.train() global_step = 0 for input_ids, input_mask, segment_ids, label_ids in train_dataloader: - input_ids.to(device) - input_mask.to(device) - segment_ids.to(device) - label_ids.to(device) + input_ids = input_ids.to(device) + input_mask = input_mask.float().to(device) + segment_ids = segment_ids.to(device) + label_ids = label_ids.to(device) - loss = model(input_ids, segment_ids, input_mask, label_ids) + loss, _ = model(input_ids, segment_ids, input_mask, label_ids) loss.backward() optimizer.step() global_step += 1 @@ -538,10 +538,10 @@ def main(): logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Batch size = %d", args.eval_batch_size) - all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long) - all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long) - all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long) - all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long) + all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) + all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) if args.local_rank == -1: @@ -554,10 +554,10 @@ def main(): eval_loss = 0 eval_accuracy = 0 for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: - input_ids.to(device) - input_mask.to(device) - segment_ids.to(device) - label_ids.to(device) + input_ids = input_ids.to(device) + input_mask = input_mask.float().to(device) + segment_ids = segment_ids.to(device) + label_ids = label_ids.to(device) tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) tmp_eval_accuracy = accuracy(logits, label_ids)