Bug fix in examples;correct t_total for distributed training;run prediction for full dataset
This commit is contained in:
@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -155,8 +156,8 @@ class MnliProcessor(DataProcessor):
|
|||||||
if i == 0:
|
if i == 0:
|
||||||
continue
|
continue
|
||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
text_a = line[8])
|
text_a = line[8]
|
||||||
text_b = line[9])
|
text_b = line[9]
|
||||||
label = line[-1]
|
label = line[-1]
|
||||||
examples.append(
|
examples.append(
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
@@ -482,7 +483,7 @@ def main():
|
|||||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
|
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
||||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
|
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
@@ -507,10 +508,13 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
|
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
|
||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
|
||||||
]
|
]
|
||||||
|
t_total = num_train_steps
|
||||||
|
if args.local_rank != -1:
|
||||||
|
t_total = t_total // torch.distributed.get_world_size()
|
||||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_steps)
|
t_total=t_total)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
@@ -571,7 +575,7 @@ def main():
|
|||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
if args.do_eval:
|
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
eval_examples, label_list, args.max_seq_length, tokenizer)
|
eval_examples, label_list, args.max_seq_length, tokenizer)
|
||||||
@@ -583,10 +587,8 @@ def main():
|
|||||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||||
if args.local_rank == -1:
|
# Run prediction for full data
|
||||||
eval_sampler = SequentialSampler(eval_data)
|
eval_sampler = SequentialSampler(eval_data)
|
||||||
else:
|
|
||||||
eval_sampler = DistributedSampler(eval_data)
|
|
||||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import json
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import pickle
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -35,6 +36,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
|
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
|
||||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -749,6 +751,10 @@ def main():
|
|||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||||
|
parser.add_argument("--do_lower_case",
|
||||||
|
default=True,
|
||||||
|
action='store_true',
|
||||||
|
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||||
parser.add_argument("--local_rank",
|
parser.add_argument("--local_rank",
|
||||||
type=int,
|
type=int,
|
||||||
default=-1,
|
default=-1,
|
||||||
@@ -845,20 +851,34 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
|
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
|
||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
|
||||||
]
|
]
|
||||||
|
t_total = num_train_steps
|
||||||
|
if args.local_rank != -1:
|
||||||
|
t_total = t_total // torch.distributed.get_world_size()
|
||||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_steps)
|
t_total=t_total)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_features = convert_examples_to_features(
|
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
|
||||||
examples=train_examples,
|
args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
|
||||||
tokenizer=tokenizer,
|
train_features = None
|
||||||
max_seq_length=args.max_seq_length,
|
try:
|
||||||
doc_stride=args.doc_stride,
|
with open(cached_train_features_file, "rb") as reader:
|
||||||
max_query_length=args.max_query_length,
|
train_features = pickle.load(reader)
|
||||||
is_training=True)
|
except:
|
||||||
|
train_features = convert_examples_to_features(
|
||||||
|
examples=train_examples,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
doc_stride=args.doc_stride,
|
||||||
|
max_query_length=args.max_query_length,
|
||||||
|
is_training=True)
|
||||||
|
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||||
|
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||||
|
with open(cached_train_features_file, "wb") as writer:
|
||||||
|
train_features = pickle.dump(train_features, writer)
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num orig examples = %d", len(train_examples))
|
logger.info(" Num orig examples = %d", len(train_examples))
|
||||||
logger.info(" Num split examples = %d", len(train_features))
|
logger.info(" Num split examples = %d", len(train_features))
|
||||||
@@ -913,7 +933,7 @@ def main():
|
|||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
eval_examples = read_squad_examples(
|
eval_examples = read_squad_examples(
|
||||||
input_file=args.predict_file, is_training=False)
|
input_file=args.predict_file, is_training=False)
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
@@ -934,10 +954,8 @@ def main():
|
|||||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
||||||
if args.local_rank == -1:
|
# Run prediction for full data
|
||||||
eval_sampler = SequentialSampler(eval_data)
|
eval_sampler = SequentialSampler(eval_data)
|
||||||
else:
|
|
||||||
eval_sampler = DistributedSampler(eval_data)
|
|
||||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
|
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user