Clean up a little bit

This commit is contained in:
samuel.broscheit
2019-05-12 00:31:10 +02:00
parent 3bf3f9596f
commit 49a77ac16f
3 changed files with 72 additions and 70 deletions

View File

@@ -894,14 +894,31 @@ def main():
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
train_examples = None
num_train_optimization_steps = None
# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
if args.do_train:
# Prepare data loader
train_examples = read_squad_examples(
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
train_features = None
try:
with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader)
@@ -933,25 +950,8 @@ def main():
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
# Prepare optimizer
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
# Prepare optimizer
if args.do_train:
param_optimizer = list(model.named_parameters())
# hack to remove pooler, which is not used
@@ -987,8 +987,8 @@ def main():
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
global_step = 0
if args.do_train:
global_step = 0
logger.info("***** Running training *****")
logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features))