update data processors __init__
This commit is contained in:
@@ -76,7 +76,7 @@ from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CAC
|
|||||||
|
|
||||||
from .data import (is_sklearn_available,
|
from .data import (is_sklearn_available,
|
||||||
InputExample, InputFeatures, DataProcessor,
|
InputExample, InputFeatures, DataProcessor,
|
||||||
glue_output_modes, glue_convert_examples_to_features, glue_processors)
|
glue_output_modes, glue_convert_examples_to_features, glue_processors, glue_tasks_num_labels)
|
||||||
|
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
from .data import glue_compute_metrics
|
from .data import glue_compute_metrics
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from .processors import (InputExample, InputFeatures, DataProcessor,
|
from .processors import InputExample, InputFeatures, DataProcessor
|
||||||
glue_output_modes, glue_convert_examples_to_features, glue_processors)
|
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||||
from .metrics import is_sklearn_available
|
|
||||||
|
|
||||||
|
from .metrics import is_sklearn_available
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
from .metrics import glue_compute_metrics
|
from .metrics import glue_compute_metrics
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from .utils import InputExample, InputFeatures, DataProcessor
|
from .utils import InputExample, InputFeatures, DataProcessor
|
||||||
from .glue import output_modes, processors, convert_examples_to_glue_features
|
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||||
|
|
||||||
|
|||||||
@@ -22,45 +22,7 @@ from .utils import DataProcessor, InputExample, InputFeatures
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GLUE_TASKS_NUM_LABELS = {
|
def glue_convert_examples_to_features(examples, label_list, max_seq_length,
|
||||||
"cola": 2,
|
|
||||||
"mnli": 3,
|
|
||||||
"mrpc": 2,
|
|
||||||
"sst-2": 2,
|
|
||||||
"sts-b": 1,
|
|
||||||
"qqp": 2,
|
|
||||||
"qnli": 2,
|
|
||||||
"rte": 2,
|
|
||||||
"wnli": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
processors = {
|
|
||||||
"cola": ColaProcessor,
|
|
||||||
"mnli": MnliProcessor,
|
|
||||||
"mnli-mm": MnliMismatchedProcessor,
|
|
||||||
"mrpc": MrpcProcessor,
|
|
||||||
"sst-2": Sst2Processor,
|
|
||||||
"sts-b": StsbProcessor,
|
|
||||||
"qqp": QqpProcessor,
|
|
||||||
"qnli": QnliProcessor,
|
|
||||||
"rte": RteProcessor,
|
|
||||||
"wnli": WnliProcessor,
|
|
||||||
}
|
|
||||||
|
|
||||||
output_modes = {
|
|
||||||
"cola": "classification",
|
|
||||||
"mnli": "classification",
|
|
||||||
"mnli-mm": "classification",
|
|
||||||
"mrpc": "classification",
|
|
||||||
"sst-2": "classification",
|
|
||||||
"sts-b": "regression",
|
|
||||||
"qqp": "classification",
|
|
||||||
"qnli": "classification",
|
|
||||||
"rte": "classification",
|
|
||||||
"wnli": "classification",
|
|
||||||
}
|
|
||||||
|
|
||||||
def convert_examples_to_glue_features(examples, label_list, max_seq_length,
|
|
||||||
tokenizer, output_mode,
|
tokenizer, output_mode,
|
||||||
pad_on_left=False,
|
pad_on_left=False,
|
||||||
pad_token=0,
|
pad_token=0,
|
||||||
@@ -427,3 +389,41 @@ class WnliProcessor(DataProcessor):
|
|||||||
examples.append(
|
examples.append(
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
glue_tasks_num_labels = {
|
||||||
|
"cola": 2,
|
||||||
|
"mnli": 3,
|
||||||
|
"mrpc": 2,
|
||||||
|
"sst-2": 2,
|
||||||
|
"sts-b": 1,
|
||||||
|
"qqp": 2,
|
||||||
|
"qnli": 2,
|
||||||
|
"rte": 2,
|
||||||
|
"wnli": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
glue_processors = {
|
||||||
|
"cola": ColaProcessor,
|
||||||
|
"mnli": MnliProcessor,
|
||||||
|
"mnli-mm": MnliMismatchedProcessor,
|
||||||
|
"mrpc": MrpcProcessor,
|
||||||
|
"sst-2": Sst2Processor,
|
||||||
|
"sts-b": StsbProcessor,
|
||||||
|
"qqp": QqpProcessor,
|
||||||
|
"qnli": QnliProcessor,
|
||||||
|
"rte": RteProcessor,
|
||||||
|
"wnli": WnliProcessor,
|
||||||
|
}
|
||||||
|
|
||||||
|
glue_output_modes = {
|
||||||
|
"cola": "classification",
|
||||||
|
"mnli": "classification",
|
||||||
|
"mnli-mm": "classification",
|
||||||
|
"mrpc": "classification",
|
||||||
|
"sst-2": "classification",
|
||||||
|
"sts-b": "regression",
|
||||||
|
"qqp": "classification",
|
||||||
|
"qnli": "classification",
|
||||||
|
"rte": "classification",
|
||||||
|
"wnli": "classification",
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user