clean up examples - added squad example and test

This commit is contained in:
thomwolf
2019-07-12 14:16:06 +02:00
parent 699bc7e86e
commit 936e813c84
14 changed files with 1266 additions and 1513 deletions

View File

@@ -306,9 +306,9 @@ def main():
help="Set this flag if you are using an uncased model.")
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
help="Batch size per GPU for training.")
help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
help="Batch size per GPU for evaluation.")
help="Batch size per GPU/CPU for evaluation.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
@@ -395,8 +395,7 @@ def main():
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
args.model_type = ""
for key in MODEL_CLASSES:
@@ -409,7 +408,7 @@ def main():
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
if args.local_rank == 0:
torch.distributed.barrier()
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
# Distributed and parrallel training
model.to(args.device)
@@ -422,6 +421,7 @@ def main():
logger.info("Training/evaluation parameters %s", args)
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
@@ -450,6 +450,7 @@ def main():
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
@@ -459,7 +460,7 @@ def main():
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1]
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step)