Create DataParallel model if several GPUs

This commit is contained in:
VictorSanh
2018-11-03 10:10:01 -04:00
parent 5889765a7c
commit 5f432480c0
3 changed files with 9 additions and 0 deletions

View File

@@ -250,6 +250,9 @@ def main():
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device) model.to(device)
if n_gpu > 1:
model = nn.DataParallel(model)
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

View File

@@ -483,6 +483,9 @@ def main():
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device) model.to(device)
if n_gpu > 1:
model = torch.nn.DataParallel(model)
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01}, optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.} {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
], ],

View File

@@ -796,6 +796,9 @@ def main():
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device) model.to(device)
if n_gpu > 1:
model = torch.nn.DataParallel(model)
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01}, optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.} {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
], ],