From 94247ad6cb8404307be31e33cc38ca98a274d21e Mon Sep 17 00:00:00 2001 From: samuelbroscheit Date: Mon, 13 May 2019 12:38:22 +0200 Subject: [PATCH] Make num_train_optimization_steps int --- examples/run_classifier.py | 2 +- examples/run_squad.py | 2 +- examples/run_swag.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 908559d577..94099204de 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -777,7 +777,7 @@ def main(): train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) - num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs + num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs if args.local_rank != -1: num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() diff --git a/examples/run_squad.py b/examples/run_squad.py index 8ce8b60294..b145303fb0 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -946,7 +946,7 @@ def main(): else: train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) - num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs + num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs if args.local_rank != -1: num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() diff --git a/examples/run_swag.py b/examples/run_swag.py index daae3971f7..73cab42830 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -393,7 +393,7 @@ def main(): train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) - num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs + num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs if args.local_rank != -1: num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()