switch to pytorch DistributedDataParallel
This commit is contained in:
@@ -902,12 +902,12 @@ def main():
|
|||||||
model.half()
|
model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
try:
|
# try:
|
||||||
from apex.parallel import DistributedDataParallel as DDP
|
# from apex.parallel import DistributedDataParallel as DDP
|
||||||
except ImportError:
|
# except ImportError:
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
model = DDP(model)
|
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||||
elif n_gpu > 1:
|
elif n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user