Merge pull request #91 from rodgzilla/convert-examples-code-improvement
run_classifier.py improvements
This commit is contained in:
@@ -196,9 +196,7 @@ class ColaProcessor(DataProcessor):
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
label_map = {}
|
||||
for (i, label) in enumerate(label_list):
|
||||
label_map[label] = i
|
||||
label_map = {label : i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in enumerate(examples):
|
||||
@@ -207,8 +205,6 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
|
||||
tokens_b = None
|
||||
if example.text_b:
|
||||
tokens_b = tokenizer.tokenize(example.text_b)
|
||||
|
||||
if tokens_b:
|
||||
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||
# length is less than the specified length.
|
||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||
@@ -216,7 +212,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
|
||||
else:
|
||||
# Account for [CLS] and [SEP] with "- 2"
|
||||
if len(tokens_a) > max_seq_length - 2:
|
||||
tokens_a = tokens_a[0:(max_seq_length - 2)]
|
||||
tokens_a = tokens_a[:(max_seq_length - 2)]
|
||||
|
||||
# The convention in BERT is:
|
||||
# (a) For sequence pairs:
|
||||
@@ -236,22 +232,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
|
||||
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||
# used as as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = []
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in tokens_a:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
|
||||
segment_ids = [0] * len(tokens)
|
||||
|
||||
if tokens_b:
|
||||
for token in tokens_b:
|
||||
tokens.append(token)
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
tokens += tokens_b + ["[SEP]"]
|
||||
segment_ids += [1] * (len(tokens_b) + 1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
@@ -260,10 +246,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
padding = [0] * (max_seq_length - len(input_ids))
|
||||
input_ids += padding
|
||||
input_mask += padding
|
||||
segment_ids += padding
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
@@ -437,6 +423,12 @@ def main():
|
||||
"mrpc": MrpcProcessor,
|
||||
}
|
||||
|
||||
num_labels_task = {
|
||||
"cola": 2,
|
||||
"mnli": 3,
|
||||
"mrpc": 2,
|
||||
}
|
||||
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
@@ -475,6 +467,7 @@ def main():
|
||||
raise ValueError("Task not found: %s" % (task_name))
|
||||
|
||||
processor = processors[task_name]()
|
||||
num_labels = num_labels_task[task_name]
|
||||
label_list = processor.get_labels()
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
@@ -488,7 +481,8 @@ def main():
|
||||
|
||||
# Prepare model
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
||||
num_labels = num_labels)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
||||
Reference in New Issue
Block a user