From 3ddff783c142f56b3b48cb1ab8ec919281ee9800 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sun, 4 Nov 2018 21:26:44 +0100 Subject: [PATCH] clean up + mask is long --- run_classifier.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/run_classifier.py b/run_classifier.py index 2b38b2be44..f1a102253a 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -22,10 +22,10 @@ import csv import os import logging import argparse - import random -import numpy as np from tqdm import tqdm, trange + +import numpy as np import torch from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -102,7 +102,7 @@ class MrpcProcessor(DataProcessor): def get_train_examples(self, data_dir): """See base class.""" - print("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) + logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) return self._create_examples( self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") @@ -420,7 +420,7 @@ def main(): n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') - print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) + logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) if args.accumulate_gradients < 1: raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( @@ -516,7 +516,7 @@ def main(): nb_tr_examples, nb_tr_steps = 0, 0 for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")): input_ids = input_ids.to(device) - input_mask = input_mask.float().to(device) + input_mask = input_mask.to(device) segment_ids = segment_ids.to(device) label_ids = label_ids.to(device) @@ -559,7 +559,7 @@ def main(): nb_eval_steps, nb_eval_examples = 0, 0 for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: input_ids = input_ids.to(device) - input_mask = input_mask.float().to(device) + input_mask = input_mask.to(device) segment_ids = segment_ids.to(device) label_ids = label_ids.to(device)