clean up + mask is long
This commit is contained in:
@@ -22,10 +22,10 @@ import csv
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
@@ -102,7 +102,7 @@ class MrpcProcessor(DataProcessor):
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
print("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
|
||||
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
@@ -420,7 +420,7 @@ def main():
|
||||
n_gpu = 1
|
||||
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
|
||||
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
|
||||
|
||||
if args.accumulate_gradients < 1:
|
||||
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
|
||||
@@ -516,7 +516,7 @@ def main():
|
||||
nb_tr_examples, nb_tr_steps = 0, 0
|
||||
for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
@@ -559,7 +559,7 @@ def main():
|
||||
nb_eval_steps, nb_eval_examples = 0, 0
|
||||
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user