From 0aaedcc02f2a2c65d966e08c8162685e25cab519 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 27 Nov 2018 01:08:37 -0800 Subject: [PATCH] Bug fix in examples; correct t_total for distributed training; run prediction for full dataset --- examples/run_classifier.py | 20 +++++++++-------- examples/run_squad.py | 44 +++++++++++++++++++++++++++----------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index c6acc091ef..2c83b4fe49 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.modeling import BertForSequenceClassification from pytorch_pretrained_bert.optimization import BertAdam +from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -155,8 +156,8 @@ class MnliProcessor(DataProcessor): if i == 0: continue guid = "%s-%s" % (set_type, line[0]) - text_a = line[8]) - text_b = line[9]) + text_a = line[8] + text_b = line[9] label = line[-1] examples.append( InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) @@ -482,7 +483,7 @@ def main(): len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) # Prepare model - model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list), + model = BertForSequenceClassification.from_pretrained(args.bert_model, cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) if args.fp16: model.half() @@ -507,10 +508,13 @@ def main(): {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 
'weight_decay_rate': 0.0} ] + t_total = num_train_steps + if args.local_rank != -1: + t_total = t_total // torch.distributed.get_world_size() optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, - t_total=num_train_steps) + t_total=t_total) global_step = 0 if args.do_train: @@ -571,7 +575,7 @@ def main(): model.zero_grad() global_step += 1 - if args.do_eval: + if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = processor.get_dev_examples(args.data_dir) eval_features = convert_examples_to_features( eval_examples, label_list, args.max_seq_length, tokenizer) @@ -583,10 +587,8 @@ def main(): all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) - if args.local_rank == -1: - eval_sampler = SequentialSampler(eval_data) - else: - eval_sampler = DistributedSampler(eval_data) + # Run prediction for full data + eval_sampler = SequentialSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) model.eval() diff --git a/examples/run_squad.py b/examples/run_squad.py index 00d5610afe..e3213189bf 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -25,6 +25,7 @@ import json import math import os import random +import pickle from tqdm import tqdm, trange import numpy as np @@ -35,6 +36,7 @@ from torch.utils.data.distributed import DistributedSampler from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer from pytorch_pretrained_bert.modeling import BertForQuestionAnswering from pytorch_pretrained_bert.optimization import BertAdam +from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE logging.basicConfig(format = '%(asctime)s - %(levelname)s - 
%(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -749,6 +751,10 @@ def main(): type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--do_lower_case", + default=True, + action='store_true', + help="Whether to lower case the input text. True for uncased models, False for cased models.") parser.add_argument("--local_rank", type=int, default=-1, @@ -845,20 +851,34 @@ def main(): {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} ] + t_total = num_train_steps + if args.local_rank != -1: + t_total = t_total // torch.distributed.get_world_size() optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, - t_total=num_train_steps) + t_total=t_total) global_step = 0 if args.do_train: - train_features = convert_examples_to_features( - examples=train_examples, - tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - doc_stride=args.doc_stride, - max_query_length=args.max_query_length, - is_training=True) + cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format( + args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length)) + train_features = None + try: + with open(cached_train_features_file, "rb") as reader: + train_features = pickle.load(reader) + except: + train_features = convert_examples_to_features( + examples=train_examples, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + doc_stride=args.doc_stride, + max_query_length=args.max_query_length, + is_training=True) + if args.local_rank == -1 or torch.distributed.get_rank() == 0: + logger.info(" Saving train features into cached file %s", cached_train_features_file) + with open(cached_train_features_file, "wb") as writer: + train_features = 
pickle.dump(train_features, writer) logger.info("***** Running training *****") logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num split examples = %d", len(train_features)) @@ -913,7 +933,7 @@ def main(): model.zero_grad() global_step += 1 - if args.do_predict: + if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = read_squad_examples( input_file=args.predict_file, is_training=False) eval_features = convert_examples_to_features( @@ -934,10 +954,8 @@ def main(): all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) - if args.local_rank == -1: - eval_sampler = SequentialSampler(eval_data) - else: - eval_sampler = DistributedSampler(eval_data) + # Run prediction for full data + eval_sampler = SequentialSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) model.eval()