updated data processor and metrics
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -130,5 +130,5 @@ runs
|
|||||||
examples/runs
|
examples/runs
|
||||||
|
|
||||||
# data
|
# data
|
||||||
data
|
/data
|
||||||
serialization_dir
|
serialization_dir
|
||||||
@@ -46,7 +46,10 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
|||||||
|
|
||||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
from pytorch_transformers.preprocessing import (compute_metrics, output_modes, processors, convert_examples_to_glue_features)
|
from pytorch_transformers import glue_compute_metrics as compute_metrics
|
||||||
|
from pytorch_transformers import glue_output_modes as output_modes
|
||||||
|
from pytorch_transformers import glue_processors as processors
|
||||||
|
from pytorch_transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -275,7 +278,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
|||||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||||
features = convert_examples_to_glue_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||||
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
||||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
||||||
|
|||||||
@@ -73,3 +73,10 @@ from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, Wa
|
|||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
cached_path, add_start_docstrings, add_end_docstrings,
|
cached_path, add_start_docstrings, add_end_docstrings,
|
||||||
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
||||||
|
|
||||||
|
from .data import (is_sklearn_available,
|
||||||
|
InputExample, InputFeatures, DataProcessor,
|
||||||
|
glue_output_modes, glue_convert_examples_to_features, glue_processors)
|
||||||
|
|
||||||
|
if is_sklearn_available():
|
||||||
|
from .data import glue_compute_metrics
|
||||||
|
|||||||
6
pytorch_transformers/data/__init__.py
Normal file
6
pytorch_transformers/data/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .processors import (InputExample, InputFeatures, DataProcessor,
|
||||||
|
glue_output_modes, glue_convert_examples_to_features, glue_processors)
|
||||||
|
from .metrics import is_sklearn_available
|
||||||
|
|
||||||
|
if is_sklearn_available():
|
||||||
|
from .metrics import glue_compute_metrics
|
||||||
83
pytorch_transformers/data/metrics/__init__.py
Normal file
83
pytorch_transformers/data/metrics/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# scipy and scikit-learn are optional dependencies: record whether they could
# be imported instead of failing when this module is loaded.
try:
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import matthews_corrcoef, f1_score
    _has_sklearn = True
except ImportError:
    # The previous `except e:` looked up an undefined name `e` while handling
    # the failed import, which raised NameError and defeated this fallback.
    logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
    _has_sklearn = False
|
||||||
|
|
||||||
|
def is_sklearn_available():
    """Return the module-level `_has_sklearn` flag."""
    sklearn_found = _has_sklearn
    return sklearn_found
|
||||||
|
|
||||||
|
# The metric helpers are only defined when the `_has_sklearn` flag is set.
if _has_sklearn:

    def simple_accuracy(preds, labels):
        """Return the mean of the element-wise comparison `preds == labels`."""
        matches = preds == labels
        return matches.mean()

    def acc_and_f1(preds, labels):
        """Return accuracy, F1 and the average of the two.

        Returns:
            dict with keys "acc", "f1" and "acc_and_f1".
        """
        accuracy = simple_accuracy(preds, labels)
        f1_value = f1_score(y_true=labels, y_pred=preds)
        scores = {"acc": accuracy, "f1": f1_value}
        scores["acc_and_f1"] = (accuracy + f1_value) / 2
        return scores

    def pearson_and_spearman(preds, labels):
        """Return the Pearson and Spearman correlations and their average.

        Returns:
            dict with keys "pearson", "spearmanr" and "corr".
        """
        pearson_result = pearsonr(preds, labels)
        spearman_result = spearmanr(preds, labels)
        # Only the first element of each result is reported.
        pearson_corr = pearson_result[0]
        spearman_corr = spearman_result[0]
        return {
            "pearson": pearson_corr,
            "spearmanr": spearman_corr,
            "corr": (pearson_corr + spearman_corr) / 2,
        }

    def glue_compute_metrics(task_name, preds, labels):
        """Compute the metric dict for the given GLUE task name.

        Args:
            task_name: one of "cola", "sst-2", "mrpc", "sts-b", "qqp", "mnli",
                "mnli-mm", "qnli", "rte", "wnli".
            preds: predictions; must have the same length as `labels`.
            labels: reference labels.

        Returns:
            dict mapping metric names to values.

        Raises:
            AssertionError: if `preds` and `labels` differ in length.
            KeyError: if `task_name` is not one of the names listed above.
        """
        assert len(preds) == len(labels)
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        if task_name in ("mrpc", "qqp"):
            return acc_and_f1(preds, labels)
        if task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        if task_name in ("sst-2", "mnli", "mnli-mm", "qnli", "rte", "wnli"):
            return {"acc": simple_accuracy(preds, labels)}
        raise KeyError(task_name)
|
||||||
2
pytorch_transformers/data/processors/__init__.py
Normal file
2
pytorch_transformers/data/processors/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .utils import InputExample, InputFeatures, DataProcessor
|
||||||
|
from .glue import output_modes, processors, convert_examples_to_glue_features
|
||||||
@@ -15,12 +15,50 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" GLUE processors and helpers """
|
""" GLUE processors and helpers """
|
||||||
|
|
||||||
from .utils import DataProcessor
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from .utils import DataProcessor, InputExample, InputFeatures
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
|
"cola": 2,
|
||||||
|
"mnli": 3,
|
||||||
|
"mrpc": 2,
|
||||||
|
"sst-2": 2,
|
||||||
|
"sts-b": 1,
|
||||||
|
"qqp": 2,
|
||||||
|
"qnli": 2,
|
||||||
|
"rte": 2,
|
||||||
|
"wnli": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
processors = {
|
||||||
|
"cola": ColaProcessor,
|
||||||
|
"mnli": MnliProcessor,
|
||||||
|
"mnli-mm": MnliMismatchedProcessor,
|
||||||
|
"mrpc": MrpcProcessor,
|
||||||
|
"sst-2": Sst2Processor,
|
||||||
|
"sts-b": StsbProcessor,
|
||||||
|
"qqp": QqpProcessor,
|
||||||
|
"qnli": QnliProcessor,
|
||||||
|
"rte": RteProcessor,
|
||||||
|
"wnli": WnliProcessor,
|
||||||
|
}
|
||||||
|
|
||||||
|
output_modes = {
|
||||||
|
"cola": "classification",
|
||||||
|
"mnli": "classification",
|
||||||
|
"mnli-mm": "classification",
|
||||||
|
"mrpc": "classification",
|
||||||
|
"sst-2": "classification",
|
||||||
|
"sts-b": "regression",
|
||||||
|
"qqp": "classification",
|
||||||
|
"qnli": "classification",
|
||||||
|
"rte": "classification",
|
||||||
|
"wnli": "classification",
|
||||||
|
}
|
||||||
|
|
||||||
def convert_examples_to_glue_features(examples, label_list, max_seq_length,
|
def convert_examples_to_glue_features(examples, label_list, max_seq_length,
|
||||||
tokenizer, output_mode,
|
tokenizer, output_mode,
|
||||||
@@ -91,37 +129,6 @@ def convert_examples_to_glue_features(examples, label_list, max_seq_length,
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
class InputExample(object):
|
|
||||||
"""A single training/test example for simple sequence classification."""
|
|
||||||
|
|
||||||
def __init__(self, guid, text_a, text_b=None, label=None):
|
|
||||||
"""Constructs a InputExample.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guid: Unique id for the example.
|
|
||||||
text_a: string. The untokenized text of the first sequence. For single
|
|
||||||
sequence tasks, only this sequence must be specified.
|
|
||||||
text_b: (Optional) string. The untokenized text of the second sequence.
|
|
||||||
Only must be specified for sequence pair tasks.
|
|
||||||
label: (Optional) string. The label of the example. This should be
|
|
||||||
specified for train and dev examples, but not for test examples.
|
|
||||||
"""
|
|
||||||
self.guid = guid
|
|
||||||
self.text_a = text_a
|
|
||||||
self.text_b = text_b
|
|
||||||
self.label = label
|
|
||||||
|
|
||||||
|
|
||||||
class InputFeatures(object):
|
|
||||||
"""A single set of features of data."""
|
|
||||||
|
|
||||||
def __init__(self, input_ids, input_mask, segment_ids, label_id):
|
|
||||||
self.input_ids = input_ids
|
|
||||||
self.input_mask = input_mask
|
|
||||||
self.segment_ids = segment_ids
|
|
||||||
self.label_id = label_id
|
|
||||||
|
|
||||||
|
|
||||||
class MrpcProcessor(DataProcessor):
|
class MrpcProcessor(DataProcessor):
|
||||||
"""Processor for the MRPC data set (GLUE version)."""
|
"""Processor for the MRPC data set (GLUE version)."""
|
||||||
|
|
||||||
@@ -420,15 +427,3 @@ class WnliProcessor(DataProcessor):
|
|||||||
examples.append(
|
examples.append(
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
GLUE_TASKS_NUM_LABELS = {
|
|
||||||
"cola": 2,
|
|
||||||
"mnli": 3,
|
|
||||||
"mrpc": 2,
|
|
||||||
"sst-2": 2,
|
|
||||||
"sts-b": 1,
|
|
||||||
"qqp": 2,
|
|
||||||
"qnli": 2,
|
|
||||||
"rte": 2,
|
|
||||||
"wnli": 2,
|
|
||||||
}
|
|
||||||
@@ -17,8 +17,34 @@
|
|||||||
import csv
|
import csv
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from scipy.stats import pearsonr, spearmanr
|
class InputExample(object):
|
||||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
"""A single training/test example for simple sequence classification."""
|
||||||
|
def __init__(self, guid, text_a, text_b=None, label=None):
|
||||||
|
"""Constructs a InputExample.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
guid: Unique id for the example.
|
||||||
|
text_a: string. The untokenized text of the first sequence. For single
|
||||||
|
sequence tasks, only this sequence must be specified.
|
||||||
|
text_b: (Optional) string. The untokenized text of the second sequence.
|
||||||
|
Only must be specified for sequence pair tasks.
|
||||||
|
label: (Optional) string. The label of the example. This should be
|
||||||
|
specified for train and dev examples, but not for test examples.
|
||||||
|
"""
|
||||||
|
self.guid = guid
|
||||||
|
self.text_a = text_a
|
||||||
|
self.text_b = text_b
|
||||||
|
self.label = label
|
||||||
|
|
||||||
|
|
||||||
|
class InputFeatures(object):
    """Holds the feature values built for one example."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        """Store each argument unchanged on the attribute of the same name."""
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id
|
||||||
|
|
||||||
|
|
||||||
class DataProcessor(object):
|
class DataProcessor(object):
|
||||||
@@ -47,53 +73,3 @@ class DataProcessor(object):
|
|||||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
def simple_accuracy(preds, labels):
|
|
||||||
return (preds == labels).mean()
|
|
||||||
|
|
||||||
|
|
||||||
def acc_and_f1(preds, labels):
|
|
||||||
acc = simple_accuracy(preds, labels)
|
|
||||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
|
||||||
return {
|
|
||||||
"acc": acc,
|
|
||||||
"f1": f1,
|
|
||||||
"acc_and_f1": (acc + f1) / 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def pearson_and_spearman(preds, labels):
|
|
||||||
pearson_corr = pearsonr(preds, labels)[0]
|
|
||||||
spearman_corr = spearmanr(preds, labels)[0]
|
|
||||||
return {
|
|
||||||
"pearson": pearson_corr,
|
|
||||||
"spearmanr": spearman_corr,
|
|
||||||
"corr": (pearson_corr + spearman_corr) / 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def compute_metrics(task_name, preds, labels):
|
|
||||||
assert len(preds) == len(labels)
|
|
||||||
if task_name == "cola":
|
|
||||||
return {"mcc": matthews_corrcoef(labels, preds)}
|
|
||||||
elif task_name == "sst-2":
|
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
|
||||||
elif task_name == "mrpc":
|
|
||||||
return acc_and_f1(preds, labels)
|
|
||||||
elif task_name == "sts-b":
|
|
||||||
return pearson_and_spearman(preds, labels)
|
|
||||||
elif task_name == "qqp":
|
|
||||||
return acc_and_f1(preds, labels)
|
|
||||||
elif task_name == "mnli":
|
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
|
||||||
elif task_name == "mnli-mm":
|
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
|
||||||
elif task_name == "qnli":
|
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
|
||||||
elif task_name == "rte":
|
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
|
||||||
elif task_name == "wnli":
|
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
|
||||||
else:
|
|
||||||
raise KeyError(task_name)
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
|
||||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from .glue import (ColaProcessor,
|
|
||||||
MnliProcessor,
|
|
||||||
MnliMismatchedProcessor,
|
|
||||||
MrpcProcessor,
|
|
||||||
Sst2Processor,
|
|
||||||
StsbProcessor,
|
|
||||||
QqpProcessor,
|
|
||||||
QnliProcessor,
|
|
||||||
RteProcessor,
|
|
||||||
WnliProcessor,
|
|
||||||
convert_examples_to_glue_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .utils import DataProcessor, simple_accuracy, acc_and_f1, pearson_and_spearman, compute_metrics
|
|
||||||
|
|
||||||
processors = {
|
|
||||||
"cola": ColaProcessor,
|
|
||||||
"mnli": MnliProcessor,
|
|
||||||
"mnli-mm": MnliMismatchedProcessor,
|
|
||||||
"mrpc": MrpcProcessor,
|
|
||||||
"sst-2": Sst2Processor,
|
|
||||||
"sts-b": StsbProcessor,
|
|
||||||
"qqp": QqpProcessor,
|
|
||||||
"qnli": QnliProcessor,
|
|
||||||
"rte": RteProcessor,
|
|
||||||
"wnli": WnliProcessor,
|
|
||||||
}
|
|
||||||
|
|
||||||
output_modes = {
|
|
||||||
"cola": "classification",
|
|
||||||
"mnli": "classification",
|
|
||||||
"mnli-mm": "classification",
|
|
||||||
"mrpc": "classification",
|
|
||||||
"sst-2": "classification",
|
|
||||||
"sts-b": "regression",
|
|
||||||
"qqp": "classification",
|
|
||||||
"qnli": "classification",
|
|
||||||
"rte": "classification",
|
|
||||||
"wnli": "classification",
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user