Bug fix in examples;correct t_total for distributed training;run prediction for full dataset

This commit is contained in:
Li Li
2018-11-27 01:08:37 -08:00
parent 32167cdf4b
commit 0aaedcc02f
2 changed files with 42 additions and 22 deletions

View File

@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
@@ -155,8 +156,8 @@ class MnliProcessor(DataProcessor):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, line[0]) guid = "%s-%s" % (set_type, line[0])
text_a = line[8]) text_a = line[8]
text_b = line[9]) text_b = line[9]
label = line[-1] label = line[-1]
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
@@ -482,7 +483,7 @@ def main():
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model # Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list), model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
if args.fp16: if args.fp16:
model.half() model.half()
@@ -507,10 +508,13 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
] ]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=num_train_steps) t_total=t_total)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
@@ -571,7 +575,7 @@ def main():
model.zero_grad() model.zero_grad()
global_step += 1 global_step += 1
if args.do_eval: if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir) eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer) eval_examples, label_list, args.max_seq_length, tokenizer)
@@ -583,10 +587,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1: # Run prediction for full data
eval_sampler = SequentialSampler(eval_data) eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval() model.eval()

View File

@@ -25,6 +25,7 @@ import json
import math import math
import os import os
import random import random
import pickle
from tqdm import tqdm, trange from tqdm import tqdm, trange
import numpy as np import numpy as np
@@ -35,6 +36,7 @@ from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
@@ -749,6 +751,10 @@ def main():
type=int, type=int,
default=1, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.") help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--do_lower_case",
default=True,
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
default=-1, default=-1,
@@ -845,20 +851,34 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
] ]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=num_train_steps) t_total=t_total)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
train_features = convert_examples_to_features( cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
examples=train_examples, args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
tokenizer=tokenizer, train_features = None
max_seq_length=args.max_seq_length, try:
doc_stride=args.doc_stride, with open(cached_train_features_file, "rb") as reader:
max_query_length=args.max_query_length, train_features = pickle.load(reader)
is_training=True) except:
train_features = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=True)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
train_features = pickle.dump(train_features, writer)
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features)) logger.info(" Num split examples = %d", len(train_features))
@@ -913,7 +933,7 @@ def main():
model.zero_grad() model.zero_grad()
global_step += 1 global_step += 1
if args.do_predict: if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples( eval_examples = read_squad_examples(
input_file=args.predict_file, is_training=False) input_file=args.predict_file, is_training=False)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
@@ -934,10 +954,8 @@ def main():
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
if args.local_rank == -1: # Run prediction for full data
eval_sampler = SequentialSampler(eval_data) eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
model.eval() model.eval()