Adds predict stage for glue tasks, and generate result files which can be submitted to gluebenchmark.com (#4463)
* Adds predict stage for glue tasks, and generate result files which could be submitted to gluebenchmark.com website. * Use Split enum + always output the label name Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -2,7 +2,8 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -47,6 +48,12 @@ class GlueDataTrainingArguments:
|
||||
self.task_name = self.task_name.lower()
|
||||
|
||||
|
||||
class Split(Enum):
|
||||
train = "train"
|
||||
dev = "dev"
|
||||
test = "test"
|
||||
|
||||
|
||||
class GlueDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
@@ -62,16 +69,21 @@ class GlueDataset(Dataset):
|
||||
args: GlueDataTrainingArguments,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
limit_length: Optional[int] = None,
|
||||
evaluate=False,
|
||||
mode: Union[str, Split] = Split.train,
|
||||
):
|
||||
self.args = args
|
||||
processor = glue_processors[args.task_name]()
|
||||
self.processor = glue_processors[args.task_name]()
|
||||
self.output_mode = glue_output_modes[args.task_name]
|
||||
if isinstance(mode, str):
|
||||
try:
|
||||
mode = Split[mode]
|
||||
except KeyError:
|
||||
raise KeyError("mode is not a valid split name")
|
||||
# Load data features from cache or dataset file
|
||||
cached_features_file = os.path.join(
|
||||
args.data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
||||
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -88,7 +100,7 @@ class GlueDataset(Dataset):
|
||||
)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {args.data_dir}")
|
||||
label_list = processor.get_labels()
|
||||
label_list = self.processor.get_labels()
|
||||
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
@@ -96,11 +108,12 @@ class GlueDataset(Dataset):
|
||||
):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(args.data_dir)
|
||||
if evaluate
|
||||
else processor.get_train_examples(args.data_dir)
|
||||
)
|
||||
if mode == Split.dev:
|
||||
examples = self.processor.get_dev_examples(args.data_dir)
|
||||
elif mode == Split.test:
|
||||
examples = self.processor.get_test_examples(args.data_dir)
|
||||
else:
|
||||
examples = self.processor.get_train_examples(args.data_dir)
|
||||
if limit_length is not None:
|
||||
examples = examples[:limit_length]
|
||||
self.features = glue_convert_examples_to_features(
|
||||
@@ -122,3 +135,6 @@ class GlueDataset(Dataset):
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
def get_labels(self):
|
||||
return self.processor.get_labels()
|
||||
|
||||
@@ -126,7 +126,9 @@ def _glue_convert_examples_to_features(
|
||||
|
||||
label_map = {label: i for i, label in enumerate(label_list)}
|
||||
|
||||
def label_from_example(example: InputExample) -> Union[int, float]:
|
||||
def label_from_example(example: InputExample) -> Union[int, float, None]:
|
||||
if example.label is None:
|
||||
return None
|
||||
if output_mode == "classification":
|
||||
return label_map[example.label]
|
||||
elif output_mode == "regression":
|
||||
@@ -180,12 +182,16 @@ class MrpcProcessor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
@@ -193,7 +199,7 @@ class MrpcProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = line[3]
|
||||
text_b = line[4]
|
||||
label = line[0]
|
||||
label = None if set_type == "test" else line[0]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
@@ -218,12 +224,16 @@ class MnliProcessor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["contradiction", "entailment", "neutral"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
@@ -231,7 +241,7 @@ class MnliProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[8]
|
||||
text_b = line[9]
|
||||
label = line[-1]
|
||||
label = None if set_type.startswith("test") else line[-1]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
@@ -241,7 +251,11 @@ class MnliMismatchedProcessor(MnliProcessor):
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
|
||||
|
||||
|
||||
class ColaProcessor(DataProcessor):
|
||||
@@ -264,17 +278,25 @@ class ColaProcessor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
test_mode = set_type == "test"
|
||||
if test_mode:
|
||||
lines = lines[1:]
|
||||
text_index = 1 if test_mode else 3
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = line[3]
|
||||
label = line[1]
|
||||
text_a = line[text_index]
|
||||
label = None if test_mode else line[1]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
@@ -299,19 +321,23 @@ class Sst2Processor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, i)
|
||||
text_a = line[0]
|
||||
label = line[1]
|
||||
label = None if set_type == "test" else line[1]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
@@ -336,12 +362,16 @@ class StsbProcessor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return [None]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
@@ -349,7 +379,7 @@ class StsbProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[7]
|
||||
text_b = line[8]
|
||||
label = line[-1]
|
||||
label = None if set_type == "test" else line[-1]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
@@ -374,21 +404,28 @@ class QqpProcessor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
test_mode = set_type == "test"
|
||||
q1_index = 1 if test_mode else 3
|
||||
q2_index = 2 if test_mode else 4
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
try:
|
||||
text_a = line[3]
|
||||
text_b = line[4]
|
||||
label = line[5]
|
||||
text_a = line[q1_index]
|
||||
text_b = line[q2_index]
|
||||
label = None if test_mode else line[5]
|
||||
except IndexError:
|
||||
continue
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
@@ -413,14 +450,18 @@ class QnliProcessor(DataProcessor):
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["entailment", "not_entailment"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
@@ -428,7 +469,7 @@ class QnliProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
label = None if set_type == "test" else line[-1]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
@@ -453,12 +494,16 @@ class RteProcessor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["entailment", "not_entailment"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
@@ -466,7 +511,7 @@ class RteProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
label = None if set_type == "test" else line[-1]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
@@ -491,12 +536,16 @@ class WnliProcessor(DataProcessor):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
"""Creates examples for the training, dev and test sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
@@ -504,7 +553,7 @@ class WnliProcessor(DataProcessor):
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
label = None if set_type == "test" else line[-1]
|
||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
@@ -98,6 +98,10 @@ class DataProcessor:
|
||||
"""Gets a collection of `InputExample`s for the dev set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the test set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_labels(self):
|
||||
"""Gets the list of labels for this data set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
Reference in New Issue
Block a user