Make num_train_optimization_steps int

This commit is contained in:
samuelbroscheit
2019-05-13 12:38:22 +02:00
parent 49a77ac16f
commit 94247ad6cb
3 changed files with 3 additions and 3 deletions

View File

@@ -393,7 +393,7 @@ def main():
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()