distributed in extract features

This commit is contained in:
thomwolf
2018-11-04 21:25:55 +01:00
parent d9d7d1a462
commit efb44a8310

View File

@@ -25,12 +25,11 @@ import logging
import json import json
import re import re
import tokenization
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import tokenization
from modeling import BertConfig, BertModel from modeling import BertConfig, BertModel
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -226,8 +225,9 @@ def main():
else: else:
device = torch.device("cuda", args.local_rank) device = torch.device("cuda", args.local_rank)
n_gpu = 1 n_gpu = 1
# print("Initializing the distributed backend: NCCL") # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
print("device", device, "n_gpu", n_gpu) torch.distributed.init_process_group(backend='nccl')
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
layer_indexes = [int(x) for x in args.layers.split(",")] layer_indexes = [int(x) for x in args.layers.split(",")]
@@ -249,9 +249,12 @@ def main():
if args.init_checkpoint is not None: if args.init_checkpoint is not None:
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device) model.to(device)
if n_gpu > 1: if args.local_rank != -1:
model = nn.DataParallel(model) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
@@ -268,7 +271,7 @@ def main():
with open(args.output_file, "w", encoding='utf-8') as writer: with open(args.output_file, "w", encoding='utf-8') as writer:
for input_ids, input_mask, example_indices in eval_dataloader: for input_ids, input_mask, example_indices in eval_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
all_encoder_layers = all_encoder_layers all_encoder_layers = all_encoder_layers