From 49296533cad3cff1c0358f948acc72063954ff87 Mon Sep 17 00:00:00 2001 From: Zhangyx Date: Thu, 21 May 2020 21:17:44 +0800 Subject: [PATCH] Adds predict stage for glue tasks, and generate result files which can be submitted to gluebenchmark.com (#4463) * Adds predict stage for glue tasks, and generate result files which could be submitted to gluebenchmark.com website. * Use Split enum + always output the label name Co-authored-by: Julien Chaumond --- examples/bertology/run_bertology.py | 2 +- examples/text-classification/run_glue.py | 40 ++++++++-- src/transformers/data/datasets/glue.py | 36 ++++++--- src/transformers/data/processors/glue.py | 97 +++++++++++++++++------ src/transformers/data/processors/utils.py | 4 + tests/test_trainer.py | 6 +- 6 files changed, 140 insertions(+), 45 deletions(-) diff --git a/examples/bertology/run_bertology.py b/examples/bertology/run_bertology.py index 331da73954..1d498b8646 100644 --- a/examples/bertology/run_bertology.py +++ b/examples/bertology/run_bertology.py @@ -419,7 +419,7 @@ def main(): logger.info("Training/evaluation parameters %s", args) # Prepare dataset for the GLUE task - eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True) + eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev") if args.data_subset > 0: eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset))))) eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 080c648938..f7392a2857 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -135,7 +135,8 @@ def main(): # Get datasets train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None - eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None + eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None + test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None def compute_metrics(p: EvalPrediction) -> Dict: if output_mode == "classification": @@ -165,7 +166,7 @@ def main(): tokenizer.save_pretrained(training_args.output_dir) # Evaluation - results = {} + eval_results = {} if training_args.do_eval: logger.info("*** Evaluate ***") @@ -173,10 +174,10 @@ def main(): eval_datasets = [eval_dataset] if data_args.task_name == "mnli": mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") - eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True)) + eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev")) for eval_dataset in eval_datasets: - result = trainer.evaluate(eval_dataset=eval_dataset) + eval_result = trainer.evaluate(eval_dataset=eval_dataset) output_eval_file = os.path.join( training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" @@ -184,13 +185,38 @@ def main(): if trainer.is_world_master(): with open(output_eval_file, "w") as writer: logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) - for key, value in result.items(): + for key, value in eval_result.items(): logger.info(" %s = %s", key, value) writer.write("%s = %s\n" % (key, value)) - results.update(result) + eval_results.update(eval_result) - return results + if training_args.do_predict: + logging.info("*** Test ***") + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") + test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test")) + + for test_dataset in test_datasets: + predictions = trainer.predict(test_dataset=test_dataset).predictions + if output_mode == "classification": + predictions = np.argmax(predictions, axis=1) + + output_test_file = os.path.join( + training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" + ) + if trainer.is_world_master(): + with open(output_test_file, "w") as writer: + logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + if output_mode == "regression": + writer.write("%d\t%3.3f\n" % (index, item)) + else: + item = test_dataset.get_labels()[item] + writer.write("%d\t%s\n" % (index, item)) + return eval_results def _mp_fn(index): diff --git a/src/transformers/data/datasets/glue.py b/src/transformers/data/datasets/glue.py index 944eb83a3a..eaaa40f628 100644 --- a/src/transformers/data/datasets/glue.py +++ b/src/transformers/data/datasets/glue.py @@ -2,7 +2,8 @@ import logging import os import time from dataclasses import dataclass, field -from typing import List, Optional +from enum import Enum +from typing import List, Optional, Union import torch from filelock import FileLock @@ -47,6 +48,12 @@ class GlueDataTrainingArguments: self.task_name = self.task_name.lower() +class Split(Enum): + train = "train" + dev = "dev" + test = "test" + + class GlueDataset(Dataset): """ This will be superseded by a framework-agnostic approach @@ -62,16 +69,21 @@ class GlueDataset(Dataset): args: GlueDataTrainingArguments, tokenizer: PreTrainedTokenizer, limit_length: Optional[int] = None, - evaluate=False, + mode: Union[str, Split] = Split.train, ): self.args = args - processor = glue_processors[args.task_name]() + self.processor = glue_processors[args.task_name]() self.output_mode = glue_output_modes[args.task_name] + if isinstance(mode, str): + try: + mode = Split[mode] + except KeyError: + raise KeyError("mode is not a valid split name") # Load data features from cache or dataset file cached_features_file = os.path.join( args.data_dir, "cached_{}_{}_{}_{}".format( - "dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, + mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, ), ) @@ -88,7 +100,7 @@ class GlueDataset(Dataset): ) else: logger.info(f"Creating features from dataset file at {args.data_dir}") - label_list = processor.get_labels() + label_list = self.processor.get_labels() if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in ( RobertaTokenizer, RobertaTokenizerFast, @@ -96,11 +108,12 @@ class GlueDataset(Dataset): ): # HACK(label indices are swapped in RoBERTa pretrained model) label_list[1], label_list[2] = label_list[2], label_list[1] - examples = ( - processor.get_dev_examples(args.data_dir) - if evaluate - else processor.get_train_examples(args.data_dir) - ) + if mode == Split.dev: + examples = self.processor.get_dev_examples(args.data_dir) + elif mode == Split.test: + examples = self.processor.get_test_examples(args.data_dir) + else: + examples = self.processor.get_train_examples(args.data_dir) if limit_length is not None: examples = examples[:limit_length] self.features = glue_convert_examples_to_features( @@ -122,3 +135,6 @@ class GlueDataset(Dataset): def __getitem__(self, i) -> InputFeatures: return self.features[i] + + def get_labels(self): + return self.processor.get_labels() diff --git a/src/transformers/data/processors/glue.py b/src/transformers/data/processors/glue.py index cc091e2a7c..ecc43f4da4 100644 --- a/src/transformers/data/processors/glue.py +++ b/src/transformers/data/processors/glue.py @@ -126,7 +126,9 @@ def _glue_convert_examples_to_features( label_map = {label: i for i, label in enumerate(label_list)} - def label_from_example(example: InputExample) -> Union[int, float]: + def label_from_example(example: InputExample) -> Union[int, float, None]: + if example.label is None: + return None if output_mode == "classification": return label_map[example.label] elif output_mode == "regression": @@ -180,12 +182,16 @@ class MrpcProcessor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + def get_labels(self): """See base class.""" return ["0", "1"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: @@ -193,7 +199,7 @@ class MrpcProcessor(DataProcessor): guid = "%s-%s" % (set_type, i) text_a = line[3] text_b = line[4] - label = line[0] + label = None if set_type == "test" else line[0] examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples @@ -218,12 +224,16 @@ class MnliProcessor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched") + def get_labels(self): """See base class.""" return ["contradiction", "entailment", "neutral"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: @@ -231,7 +241,7 @@ class MnliProcessor(DataProcessor): guid = "%s-%s" % (set_type, line[0]) text_a = line[8] text_b = line[9] - label = line[-1] + label = None if set_type.startswith("test") else line[-1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples @@ -241,7 +251,11 @@ class MnliMismatchedProcessor(MnliProcessor): def get_dev_examples(self, data_dir): """See base class.""" - return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched") + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched") class ColaProcessor(DataProcessor): @@ -264,17 +278,25 @@ class ColaProcessor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + def get_labels(self): """See base class.""" return ["0", "1"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" + test_mode = set_type == "test" + if test_mode: + lines = lines[1:] + text_index = 1 if test_mode else 3 examples = [] for (i, line) in enumerate(lines): guid = "%s-%s" % (set_type, i) - text_a = line[3] - label = line[1] + text_a = line[text_index] + label = None if test_mode else line[1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) return examples @@ -299,19 +321,23 @@ class Sst2Processor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + def get_labels(self): """See base class.""" return ["0", "1"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "%s-%s" % (set_type, i) text_a = line[0] - label = line[1] + label = None if set_type == "test" else line[1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) return examples @@ -336,12 +362,16 @@ class StsbProcessor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + def get_labels(self): """See base class.""" return [None] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: @@ -349,7 +379,7 @@ class StsbProcessor(DataProcessor): guid = "%s-%s" % (set_type, line[0]) text_a = line[7] text_b = line[8] - label = line[-1] + label = None if set_type == "test" else line[-1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples @@ -374,21 +404,28 @@ class QqpProcessor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + def get_labels(self): """See base class.""" return ["0", "1"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" + test_mode = set_type == "test" + q1_index = 1 if test_mode else 3 + q2_index = 2 if test_mode else 4 examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "%s-%s" % (set_type, line[0]) try: - text_a = line[3] - text_b = line[4] - label = line[5] + text_a = line[q1_index] + text_b = line[q2_index] + label = None if test_mode else line[5] except IndexError: continue examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) @@ -413,14 +450,18 @@ class QnliProcessor(DataProcessor): def get_dev_examples(self, data_dir): """See base class.""" - return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched") + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") def get_labels(self): """See base class.""" return ["entailment", "not_entailment"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: @@ -428,7 +469,7 @@ class QnliProcessor(DataProcessor): guid = "%s-%s" % (set_type, line[0]) text_a = line[1] text_b = line[2] - label = line[-1] + label = None if set_type == "test" else line[-1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples @@ -453,12 +494,16 @@ class RteProcessor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + def get_labels(self): """See base class.""" return ["entailment", "not_entailment"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: @@ -466,7 +511,7 @@ class RteProcessor(DataProcessor): guid = "%s-%s" % (set_type, line[0]) text_a = line[1] text_b = line[2] - label = line[-1] + label = None if set_type == "test" else line[-1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples @@ -491,12 +536,16 @@ class WnliProcessor(DataProcessor): """See base class.""" return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + def get_labels(self): """See base class.""" return ["0", "1"] def _create_examples(self, lines, set_type): - """Creates examples for the training and dev sets.""" + """Creates examples for the training, dev and test sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0: @@ -504,7 +553,7 @@ class WnliProcessor(DataProcessor): guid = "%s-%s" % (set_type, line[0]) text_a = line[1] text_b = line[2] - label = line[-1] + label = None if set_type == "test" else line[-1] examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples diff --git a/src/transformers/data/processors/utils.py b/src/transformers/data/processors/utils.py index eb36551884..0212c58643 100644 --- a/src/transformers/data/processors/utils.py +++ b/src/transformers/data/processors/utils.py @@ -98,6 +98,10 @@ class DataProcessor: """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() + def get_test_examples(self, data_dir): + """Gets a collection of `InputExample`s for the test set.""" + raise NotImplementedError() + def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 417ebcb5a6..023f7ba6b0 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -30,7 +30,7 @@ class DataCollatorIntegrationTest(unittest.TestCase): data_args = GlueDataTrainingArguments( task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True ) - dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) + dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") data_collator = DefaultDataCollator() batch = data_collator.collate_batch(dataset.features) self.assertEqual(batch["labels"].dtype, torch.long) @@ -41,7 +41,7 @@ class DataCollatorIntegrationTest(unittest.TestCase): data_args = GlueDataTrainingArguments( task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True ) - dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) + dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") data_collator = DefaultDataCollator() batch = data_collator.collate_batch(dataset.features) self.assertEqual(batch["labels"].dtype, torch.float) @@ -93,7 +93,7 @@ class TrainerIntegrationTest(unittest.TestCase): data_args = GlueDataTrainingArguments( task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True ) - eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) + eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") training_args = TrainingArguments(output_dir="./examples", no_cuda=True) trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)