From 6b0da96b4b7593f25a6919fef89bb47a4cdb06f7 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sun, 4 Nov 2018 15:17:55 +0100 Subject: [PATCH] clean up --- run_classifier.py | 23 +++++++++++++---------- run_squad.py | 4 ++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/run_classifier.py b/run_classifier.py index c4a2e7ee61..f6fe12ff98 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -69,7 +69,7 @@ class InputFeatures(object): self.input_mask = input_mask self.segment_ids = segment_ids self.label_id = label_id - + class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" @@ -95,8 +95,8 @@ class DataProcessor(object): for line in reader: lines.append(line) return lines - - + + class MrpcProcessor(DataProcessor): """Processor for the MRPC data set (GLUE version).""" @@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor): examples.append( InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) return examples - - -def convert_examples_to_features(examples, label_list, max_seq_length, - tokenizer): + + +def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): """Loads a data file into a list of `InputBatch`s.""" label_map = {} @@ -380,7 +379,7 @@ def main(): parser.add_argument("--do_lower_case", default=False, action='store_true', - help="Whether to lower case the input text. Should be True for uncased models and False for cased models.") + help="Whether to lower case the input text. True for uncased models, False for cased models.") parser.add_argument("--max_seq_length", default=128, type=int, @@ -424,6 +423,10 @@ def main(): default=False, action='store_true', help="Whether not to use CUDA when available") + parser.add_argument("--accumulate_gradients", + type=int, + default=1, + help="Number of steps to accumulate gradient on (divide the single step batch_size)") parser.add_argument("--local_rank", type=int, default=-1, @@ -448,12 +451,12 @@ def main(): n_gpu = 1 # print("Initializing the distributed backend: NCCL") print("device", device, "n_gpu", n_gpu) - + random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu>0: torch.cuda.manual_seed_all(args.seed) - + if not args.do_train and not args.do_eval: raise ValueError("At least one of `do_train` or `do_eval` must be True.") diff --git a/run_squad.py b/run_squad.py index 51a5ad5963..434fee99de 100644 --- a/run_squad.py +++ b/run_squad.py @@ -18,15 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six import argparse import collections import logging import json import math import os -from tqdm import tqdm, trange +import six import random +from tqdm import tqdm, trange import numpy as np import torch