This commit is contained in:
thomwolf
2018-11-04 15:17:55 +01:00
parent 834b485b2e
commit 6b0da96b4b
2 changed files with 15 additions and 12 deletions

View File

@@ -69,7 +69,7 @@ class InputFeatures(object):
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
@@ -95,8 +95,8 @@ class DataProcessor(object):
for line in reader:
lines.append(line)
return lines
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
@@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor):
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer):
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
label_map = {}
@@ -380,7 +379,7 @@ def main():
parser.add_argument("--do_lower_case",
default=False,
action='store_true',
help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length",
default=128,
type=int,
@@ -424,6 +423,10 @@ def main():
default=False,
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--accumulate_gradients",
type=int,
default=1,
help="Number of steps to accumulate gradient on (divide the single step batch_size)")
parser.add_argument("--local_rank",
type=int,
default=-1,
@@ -448,12 +451,12 @@ def main():
n_gpu = 1
# print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")