Make num_train_optimization_steps int
This commit is contained in:
@@ -777,7 +777,7 @@ def main():
|
|||||||
train_sampler = DistributedSampler(train_data)
|
train_sampler = DistributedSampler(train_data)
|
||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs
|
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
|
|||||||
@@ -946,7 +946,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
train_sampler = DistributedSampler(train_data)
|
train_sampler = DistributedSampler(train_data)
|
||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs
|
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
|
|||||||
@@ -393,7 +393,7 @@ def main():
|
|||||||
train_sampler = DistributedSampler(train_data)
|
train_sampler = DistributedSampler(train_data)
|
||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs
|
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user