Adds predict stage for glue tasks, and generate result files which can be submitted to gluebenchmark.com (#4463)
* Adds predict stage for glue tasks, and generate result files which could be submitted to gluebenchmark.com website. * Use Split enum + always output the label name Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -419,7 +419,7 @@ def main():
|
|||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
# Prepare dataset for the GLUE task
|
# Prepare dataset for the GLUE task
|
||||||
eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True)
|
eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
|
||||||
if args.data_subset > 0:
|
if args.data_subset > 0:
|
||||||
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
||||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||||
|
|||||||
@@ -135,7 +135,8 @@ def main():
|
|||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
||||||
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
|
||||||
|
test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None
|
||||||
|
|
||||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
@@ -165,7 +166,7 @@ def main():
|
|||||||
tokenizer.save_pretrained(training_args.output_dir)
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
eval_results = {}
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
@@ -173,10 +174,10 @@ def main():
|
|||||||
eval_datasets = [eval_dataset]
|
eval_datasets = [eval_dataset]
|
||||||
if data_args.task_name == "mnli":
|
if data_args.task_name == "mnli":
|
||||||
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
||||||
eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))
|
eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev"))
|
||||||
|
|
||||||
for eval_dataset in eval_datasets:
|
for eval_dataset in eval_datasets:
|
||||||
result = trainer.evaluate(eval_dataset=eval_dataset)
|
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
|
||||||
|
|
||||||
output_eval_file = os.path.join(
|
output_eval_file = os.path.join(
|
||||||
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
|
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
|
||||||
@@ -184,13 +185,38 @@ def main():
|
|||||||
if trainer.is_world_master():
|
if trainer.is_world_master():
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
|
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
|
||||||
for key, value in result.items():
|
for key, value in eval_result.items():
|
||||||
logger.info(" %s = %s", key, value)
|
logger.info(" %s = %s", key, value)
|
||||||
writer.write("%s = %s\n" % (key, value))
|
writer.write("%s = %s\n" % (key, value))
|
||||||
|
|
||||||
results.update(result)
|
eval_results.update(eval_result)
|
||||||
|
|
||||||
return results
|
if training_args.do_predict:
|
||||||
|
logging.info("*** Test ***")
|
||||||
|
test_datasets = [test_dataset]
|
||||||
|
if data_args.task_name == "mnli":
|
||||||
|
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
||||||
|
test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test"))
|
||||||
|
|
||||||
|
for test_dataset in test_datasets:
|
||||||
|
predictions = trainer.predict(test_dataset=test_dataset).predictions
|
||||||
|
if output_mode == "classification":
|
||||||
|
predictions = np.argmax(predictions, axis=1)
|
||||||
|
|
||||||
|
output_test_file = os.path.join(
|
||||||
|
training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt"
|
||||||
|
)
|
||||||
|
if trainer.is_world_master():
|
||||||
|
with open(output_test_file, "w") as writer:
|
||||||
|
logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
|
||||||
|
writer.write("index\tprediction\n")
|
||||||
|
for index, item in enumerate(predictions):
|
||||||
|
if output_mode == "regression":
|
||||||
|
writer.write("%d\t%3.3f\n" % (index, item))
|
||||||
|
else:
|
||||||
|
item = test_dataset.get_labels()[item]
|
||||||
|
writer.write("%d\t%s\n" % (index, item))
|
||||||
|
return eval_results
|
||||||
|
|
||||||
|
|
||||||
def _mp_fn(index):
|
def _mp_fn(index):
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
from enum import Enum
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
@@ -47,6 +48,12 @@ class GlueDataTrainingArguments:
|
|||||||
self.task_name = self.task_name.lower()
|
self.task_name = self.task_name.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class Split(Enum):
|
||||||
|
train = "train"
|
||||||
|
dev = "dev"
|
||||||
|
test = "test"
|
||||||
|
|
||||||
|
|
||||||
class GlueDataset(Dataset):
|
class GlueDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
This will be superseded by a framework-agnostic approach
|
This will be superseded by a framework-agnostic approach
|
||||||
@@ -62,16 +69,21 @@ class GlueDataset(Dataset):
|
|||||||
args: GlueDataTrainingArguments,
|
args: GlueDataTrainingArguments,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
limit_length: Optional[int] = None,
|
limit_length: Optional[int] = None,
|
||||||
evaluate=False,
|
mode: Union[str, Split] = Split.train,
|
||||||
):
|
):
|
||||||
self.args = args
|
self.args = args
|
||||||
processor = glue_processors[args.task_name]()
|
self.processor = glue_processors[args.task_name]()
|
||||||
self.output_mode = glue_output_modes[args.task_name]
|
self.output_mode = glue_output_modes[args.task_name]
|
||||||
|
if isinstance(mode, str):
|
||||||
|
try:
|
||||||
|
mode = Split[mode]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError("mode is not a valid split name")
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(
|
cached_features_file = os.path.join(
|
||||||
args.data_dir,
|
args.data_dir,
|
||||||
"cached_{}_{}_{}_{}".format(
|
"cached_{}_{}_{}_{}".format(
|
||||||
"dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,7 +100,7 @@ class GlueDataset(Dataset):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Creating features from dataset file at {args.data_dir}")
|
logger.info(f"Creating features from dataset file at {args.data_dir}")
|
||||||
label_list = processor.get_labels()
|
label_list = self.processor.get_labels()
|
||||||
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
|
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
RobertaTokenizerFast,
|
RobertaTokenizerFast,
|
||||||
@@ -96,11 +108,12 @@ class GlueDataset(Dataset):
|
|||||||
):
|
):
|
||||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
examples = (
|
if mode == Split.dev:
|
||||||
processor.get_dev_examples(args.data_dir)
|
examples = self.processor.get_dev_examples(args.data_dir)
|
||||||
if evaluate
|
elif mode == Split.test:
|
||||||
else processor.get_train_examples(args.data_dir)
|
examples = self.processor.get_test_examples(args.data_dir)
|
||||||
)
|
else:
|
||||||
|
examples = self.processor.get_train_examples(args.data_dir)
|
||||||
if limit_length is not None:
|
if limit_length is not None:
|
||||||
examples = examples[:limit_length]
|
examples = examples[:limit_length]
|
||||||
self.features = glue_convert_examples_to_features(
|
self.features = glue_convert_examples_to_features(
|
||||||
@@ -122,3 +135,6 @@ class GlueDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, i) -> InputFeatures:
|
def __getitem__(self, i) -> InputFeatures:
|
||||||
return self.features[i]
|
return self.features[i]
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
return self.processor.get_labels()
|
||||||
|
|||||||
@@ -126,7 +126,9 @@ def _glue_convert_examples_to_features(
|
|||||||
|
|
||||||
label_map = {label: i for i, label in enumerate(label_list)}
|
label_map = {label: i for i, label in enumerate(label_list)}
|
||||||
|
|
||||||
def label_from_example(example: InputExample) -> Union[int, float]:
|
def label_from_example(example: InputExample) -> Union[int, float, None]:
|
||||||
|
if example.label is None:
|
||||||
|
return None
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
return label_map[example.label]
|
return label_map[example.label]
|
||||||
elif output_mode == "regression":
|
elif output_mode == "regression":
|
||||||
@@ -180,12 +182,16 @@ class MrpcProcessor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1"]
|
return ["0", "1"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -193,7 +199,7 @@ class MrpcProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, i)
|
guid = "%s-%s" % (set_type, i)
|
||||||
text_a = line[3]
|
text_a = line[3]
|
||||||
text_b = line[4]
|
text_b = line[4]
|
||||||
label = line[0]
|
label = None if set_type == "test" else line[0]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -218,12 +224,16 @@ class MnliProcessor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["contradiction", "entailment", "neutral"]
|
return ["contradiction", "entailment", "neutral"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -231,7 +241,7 @@ class MnliProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
text_a = line[8]
|
text_a = line[8]
|
||||||
text_b = line[9]
|
text_b = line[9]
|
||||||
label = line[-1]
|
label = None if set_type.startswith("test") else line[-1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -241,7 +251,11 @@ class MnliMismatchedProcessor(MnliProcessor):
|
|||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
|
||||||
|
|
||||||
|
|
||||||
class ColaProcessor(DataProcessor):
|
class ColaProcessor(DataProcessor):
|
||||||
@@ -264,17 +278,25 @@ class ColaProcessor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1"]
|
return ["0", "1"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
|
test_mode = set_type == "test"
|
||||||
|
if test_mode:
|
||||||
|
lines = lines[1:]
|
||||||
|
text_index = 1 if test_mode else 3
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
guid = "%s-%s" % (set_type, i)
|
guid = "%s-%s" % (set_type, i)
|
||||||
text_a = line[3]
|
text_a = line[text_index]
|
||||||
label = line[1]
|
label = None if test_mode else line[1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -299,19 +321,23 @@ class Sst2Processor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1"]
|
return ["0", "1"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
continue
|
continue
|
||||||
guid = "%s-%s" % (set_type, i)
|
guid = "%s-%s" % (set_type, i)
|
||||||
text_a = line[0]
|
text_a = line[0]
|
||||||
label = line[1]
|
label = None if set_type == "test" else line[1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -336,12 +362,16 @@ class StsbProcessor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return [None]
|
return [None]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -349,7 +379,7 @@ class StsbProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
text_a = line[7]
|
text_a = line[7]
|
||||||
text_b = line[8]
|
text_b = line[8]
|
||||||
label = line[-1]
|
label = None if set_type == "test" else line[-1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -374,21 +404,28 @@ class QqpProcessor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1"]
|
return ["0", "1"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
|
test_mode = set_type == "test"
|
||||||
|
q1_index = 1 if test_mode else 3
|
||||||
|
q2_index = 2 if test_mode else 4
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
continue
|
continue
|
||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
try:
|
try:
|
||||||
text_a = line[3]
|
text_a = line[q1_index]
|
||||||
text_b = line[4]
|
text_b = line[q2_index]
|
||||||
label = line[5]
|
label = None if test_mode else line[5]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
continue
|
continue
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
@@ -413,14 +450,18 @@ class QnliProcessor(DataProcessor):
|
|||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["entailment", "not_entailment"]
|
return ["entailment", "not_entailment"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -428,7 +469,7 @@ class QnliProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
text_a = line[1]
|
text_a = line[1]
|
||||||
text_b = line[2]
|
text_b = line[2]
|
||||||
label = line[-1]
|
label = None if set_type == "test" else line[-1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -453,12 +494,16 @@ class RteProcessor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["entailment", "not_entailment"]
|
return ["entailment", "not_entailment"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -466,7 +511,7 @@ class RteProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
text_a = line[1]
|
text_a = line[1]
|
||||||
text_b = line[2]
|
text_b = line[2]
|
||||||
label = line[-1]
|
label = None if set_type == "test" else line[-1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -491,12 +536,16 @@ class WnliProcessor(DataProcessor):
|
|||||||
"""See base class."""
|
"""See base class."""
|
||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
return ["0", "1"]
|
return ["0", "1"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
"""Creates examples for the training and dev sets."""
|
"""Creates examples for the training, dev and test sets."""
|
||||||
examples = []
|
examples = []
|
||||||
for (i, line) in enumerate(lines):
|
for (i, line) in enumerate(lines):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -504,7 +553,7 @@ class WnliProcessor(DataProcessor):
|
|||||||
guid = "%s-%s" % (set_type, line[0])
|
guid = "%s-%s" % (set_type, line[0])
|
||||||
text_a = line[1]
|
text_a = line[1]
|
||||||
text_b = line[2]
|
text_b = line[2]
|
||||||
label = line[-1]
|
label = None if set_type == "test" else line[-1]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|||||||
@@ -98,6 +98,10 @@ class DataProcessor:
|
|||||||
"""Gets a collection of `InputExample`s for the dev set."""
|
"""Gets a collection of `InputExample`s for the dev set."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_test_examples(self, data_dir):
|
||||||
|
"""Gets a collection of `InputExample`s for the test set."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""Gets the list of labels for this data set."""
|
"""Gets the list of labels for this data set."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
data_args = GlueDataTrainingArguments(
|
data_args = GlueDataTrainingArguments(
|
||||||
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
||||||
)
|
)
|
||||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
|
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||||
data_collator = DefaultDataCollator()
|
data_collator = DefaultDataCollator()
|
||||||
batch = data_collator.collate_batch(dataset.features)
|
batch = data_collator.collate_batch(dataset.features)
|
||||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||||
@@ -41,7 +41,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
data_args = GlueDataTrainingArguments(
|
data_args = GlueDataTrainingArguments(
|
||||||
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
|
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
|
||||||
)
|
)
|
||||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
|
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||||
data_collator = DefaultDataCollator()
|
data_collator = DefaultDataCollator()
|
||||||
batch = data_collator.collate_batch(dataset.features)
|
batch = data_collator.collate_batch(dataset.features)
|
||||||
self.assertEqual(batch["labels"].dtype, torch.float)
|
self.assertEqual(batch["labels"].dtype, torch.float)
|
||||||
@@ -93,7 +93,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
data_args = GlueDataTrainingArguments(
|
data_args = GlueDataTrainingArguments(
|
||||||
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
||||||
)
|
)
|
||||||
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
|
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||||
|
|
||||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||||
trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
|
trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user