From 0f96d4b1f76a7e2278a8964aceae6b89da8623de Mon Sep 17 00:00:00 2001 From: John Lehmann Date: Tue, 5 Mar 2019 13:38:28 -0600 Subject: [PATCH] Run classifier processor for SST-2. --- examples/run_classifier.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 83f0683a48..4d267fad3b 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -196,6 +196,37 @@ class ColaProcessor(DataProcessor): return examples +class Sst2Processor(DataProcessor): + """Processor for the SST-2 data set (GLUE version).""" + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples( + self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training and dev sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, i) + text_a = line[0] + label = line[1] + examples.append( + InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) + return examples + + def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): """Loads a data file into a list of `InputBatch`s.""" @@ -401,10 +432,12 @@ def main(): "cola": ColaProcessor, "mnli": MnliProcessor, "mrpc": MrpcProcessor, + "sst-2": Sst2Processor, } num_labels_task = { "cola": 2, + "sst-2": 2, "mnli": 3, "mrpc": 2, } @@ -597,7 +630,7 @@ def main(): model.eval() eval_loss, eval_accuracy = 0, 0 nb_eval_steps, nb_eval_examples = 0, 0 - + for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): input_ids = input_ids.to(device) input_mask = input_mask.to(device)