DDP update
This commit is contained in:
@@ -907,7 +907,7 @@ def main():
|
|||||||
# except ImportError:
|
# except ImportError:
|
||||||
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
|
||||||
elif n_gpu > 1:
|
elif n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user