fixes + clean up + mask is long (input_mask is no longer cast to float)

This commit is contained in:
thomwolf
2018-11-04 21:26:54 +01:00
parent 3ddff783c1
commit d69b0b0e90

View File

@@ -24,8 +24,8 @@ import logging
import json import json
import math import math
import os import os
import six
import random import random
import six
from tqdm import tqdm, trange from tqdm import tqdm, trange
import numpy as np import numpy as np
@@ -750,7 +750,7 @@ def main():
n_gpu = 1 n_gpu = 1
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend='nccl')
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
if args.accumulate_gradients < 1: if args.accumulate_gradients < 1:
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
@@ -855,7 +855,7 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
input_ids, input_mask, segment_ids, start_positions, end_positions = batch input_ids, input_mask, segment_ids, start_positions, end_positions = batch
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
start_positions = start_positions.to(device) start_positions = start_positions.to(device)
end_positions = start_positions.to(device) end_positions = start_positions.to(device)
@@ -904,12 +904,12 @@ def main():
model.eval() model.eval()
all_results = [] all_results = []
logger.info("Start evaluating") logger.info("Start evaluating")
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"): for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"):
if len(all_results) % 1000 == 0: if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results))) logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
start_logits, end_logits = model(input_ids, segment_ids, input_mask) start_logits, end_logits = model(input_ids, segment_ids, input_mask)