fixes + clean up + mask is long
This commit is contained in:
10
run_squad.py
10
run_squad.py
@@ -24,8 +24,8 @@ import logging
|
|||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import six
|
|
||||||
import random
|
import random
|
||||||
|
import six
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -750,7 +750,7 @@ def main():
|
|||||||
n_gpu = 1
|
n_gpu = 1
|
||||||
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||||
torch.distributed.init_process_group(backend='nccl')
|
torch.distributed.init_process_group(backend='nccl')
|
||||||
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
|
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
|
||||||
|
|
||||||
if args.accumulate_gradients < 1:
|
if args.accumulate_gradients < 1:
|
||||||
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
|
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
|
||||||
@@ -855,7 +855,7 @@ def main():
|
|||||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||||
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
|
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.to(device)
|
||||||
segment_ids = segment_ids.to(device)
|
segment_ids = segment_ids.to(device)
|
||||||
start_positions = start_positions.to(device)
|
start_positions = start_positions.to(device)
|
||||||
end_positions = start_positions.to(device)
|
end_positions = start_positions.to(device)
|
||||||
@@ -904,12 +904,12 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
all_results = []
|
all_results = []
|
||||||
logger.info("Start evaluating")
|
logger.info("Start evaluating")
|
||||||
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
|
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
if len(all_results) % 1000 == 0:
|
if len(all_results) % 1000 == 0:
|
||||||
logger.info("Processing example: %d" % (len(all_results)))
|
logger.info("Processing example: %d" % (len(all_results)))
|
||||||
|
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.to(device)
|
||||||
segment_ids = segment_ids.to(device)
|
segment_ids = segment_ids.to(device)
|
||||||
|
|
||||||
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
|
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
|
||||||
|
|||||||
Reference in New Issue
Block a user