Merge pull request #1 from huggingface/multi-gpu-support
Create DataParallel model if several GPUs
This commit is contained in:
@@ -249,6 +249,9 @@ def main():
|
|||||||
if args.init_checkpoint is not None:
|
if args.init_checkpoint is not None:
|
||||||
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
|
if n_gpu > 1:
|
||||||
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||||
|
|||||||
@@ -482,6 +482,9 @@ def main():
|
|||||||
if args.init_checkpoint is not None:
|
if args.init_checkpoint is not None:
|
||||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
|
if n_gpu > 1:
|
||||||
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
||||||
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
||||||
|
|||||||
@@ -795,6 +795,9 @@ def main():
|
|||||||
if args.init_checkpoint is not None:
|
if args.init_checkpoint is not None:
|
||||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
|
if n_gpu > 1:
|
||||||
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
||||||
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
||||||
|
|||||||
Reference in New Issue
Block a user