From efb44a8310f68a1ce4e5633b123c9bff3456f0bf Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sun, 4 Nov 2018 21:25:55 +0100 Subject: [PATCH] distributed in extract features --- extract_features.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/extract_features.py b/extract_features.py index b9625c266d..6ad3a90e00 100644 --- a/extract_features.py +++ b/extract_features.py @@ -25,12 +25,11 @@ import logging import json import re -import tokenization import torch - from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler +import tokenization from modeling import BertConfig, BertModel logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', @@ -226,8 +225,9 @@ def main(): else: device = torch.device("cuda", args.local_rank) n_gpu = 1 - # print("Initializing the distributed backend: NCCL") - print("device", device, "n_gpu", n_gpu) + # Initializes the distributed backend which will take care of synchronizing nodes/GPUs + torch.distributed.init_process_group(backend='nccl') + logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) layer_indexes = [int(x) for x in args.layers.split(",")] @@ -249,9 +249,12 @@ def main(): if args.init_checkpoint is not None: model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) - - if n_gpu > 1: - model = nn.DataParallel(model) + + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank) + elif n_gpu > 1: + model = torch.nn.DataParallel(model) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) @@ -268,7 +271,7 @@ def main(): with open(args.output_file, "w", encoding='utf-8') as writer: for input_ids, 
input_mask, example_indices in eval_dataloader: input_ids = input_ids.to(device) - input_mask = input_mask.float().to(device) + input_mask = input_mask.to(device) all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) all_encoder_layers = all_encoder_layers