From af62cc5f20da128980639f31a54e68bff399a11c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 11 Feb 2019 14:06:32 +0100 Subject: [PATCH] fix run_squad example --- examples/run_squad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 1d7c49c326..0e9aec81a1 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -881,7 +881,7 @@ def main(): train_examples = read_squad_examples( input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative) num_train_optimization_steps = int( - len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs + len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs if args.local_rank != -1: num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()