distributed traing t_total
This commit is contained in:
@@ -234,10 +234,11 @@ def main():
|
|||||||
train_sampler = RandomSampler(train_data)
|
train_sampler = RandomSampler(train_data)
|
||||||
else:
|
else:
|
||||||
train_sampler = DistributedSampler(train_data)
|
train_sampler = DistributedSampler(train_data)
|
||||||
|
|
||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
if args.local_rank != -1:
|
# if args.local_rank != -1:
|
||||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
# num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
# Prepare optimizer
|
# Prepare optimizer
|
||||||
param_optimizer = list(model.named_parameters())
|
param_optimizer = list(model.named_parameters())
|
||||||
|
|||||||
Reference in New Issue
Block a user