From c6d9d5394e6bf461f09e5f3e9b08e333961e590b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9gory=20Ch=C3=A2tel?= Date: Wed, 5 Dec 2018 17:53:09 +0100 Subject: [PATCH] Simplifying code for easier understanding. --- examples/run_classifier.py | 34 ++++++++++------------------------ 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index b2b8ac2630..7cfa39dabf 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -196,9 +196,7 @@ class ColaProcessor(DataProcessor): def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): """Loads a data file into a list of `InputBatch`s.""" - label_map = {} - for (i, label) in enumerate(label_list): - label_map[label] = i + label_map = {label : i for i, label in enumerate(label_list)} features = [] for (ex_index, example) in enumerate(examples): @@ -207,8 +205,6 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer tokens_b = None if example.text_b: tokens_b = tokenizer.tokenize(example.text_b) - - if tokens_b: # Modifies `tokens_a` and `tokens_b` in place so that the total # length is less than the specified length. # Account for [CLS], [SEP], [SEP] with "- 3" @@ -216,7 +212,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer else: # Account for [CLS] and [SEP] with "- 2" if len(tokens_a) > max_seq_length - 2: - tokens_a = tokens_a[0:(max_seq_length - 2)] + tokens_a = tokens_a[:(max_seq_length - 2)] # The convention in BERT is: # (a) For sequence pairs: @@ -236,22 +232,12 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer # For classification tasks, the first vector (corresponding to [CLS]) is # used as as the "sentence vector". Note that this only makes sense because # the entire model is fine-tuned. - tokens = [] - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in tokens_a: - tokens.append(token) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) + tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + segment_ids = [0] * len(tokens) if tokens_b: - for token in tokens_b: - tokens.append(token) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) + tokens += tokens_b + ["[SEP]"] + segment_ids += [1] * (len(tokens_b) + 1) input_ids = tokenizer.convert_tokens_to_ids(tokens) @@ -260,10 +246,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer input_mask = [1] * len(input_ids) # Zero-pad up to the sequence length. - while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) + padding = [0] * (max_seq_length - len(input_ids)) + input_ids += padding + input_mask += padding + segment_ids += padding assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length