clean up + mask is long

This commit is contained in:
thomwolf
2018-11-04 21:26:44 +01:00
parent 88c1037991
commit 3ddff783c1

View File

@@ -22,10 +22,10 @@ import csv
import os import os
import logging import logging
import argparse import argparse
import random import random
import numpy as np
from tqdm import tqdm, trange from tqdm import tqdm, trange
import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
@@ -102,7 +102,7 @@ class MrpcProcessor(DataProcessor):
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
print("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
return self._create_examples( return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
@@ -420,7 +420,7 @@ def main():
n_gpu = 1 n_gpu = 1
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend='nccl')
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
if args.accumulate_gradients < 1: if args.accumulate_gradients < 1:
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
@@ -516,7 +516,7 @@ def main():
nb_tr_examples, nb_tr_steps = 0, 0 nb_tr_examples, nb_tr_steps = 0, 0
for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
@@ -559,7 +559,7 @@ def main():
nb_eval_steps, nb_eval_examples = 0, 0 nb_eval_steps, nb_eval_examples = 0, 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device) label_ids = label_ids.to(device)