move xnli processor (and utils) to transformers/data/processors
This commit is contained in:
committed by
Lysandre Debut
parent
289cf4d2b7
commit
d5910b312f
@@ -25,7 +25,8 @@ from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH
|
||||
from .data import (is_sklearn_available,
|
||||
InputExample, InputFeatures, DataProcessor,
|
||||
glue_output_modes, glue_convert_examples_to_features,
|
||||
glue_processors, glue_tasks_num_labels)
|
||||
glue_processors, glue_tasks_num_labels,
|
||||
xnli_output_modes, xnli_processors, xnli_tasks_num_labels)
|
||||
|
||||
if is_sklearn_available():
|
||||
from .data import glue_compute_metrics
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .processors import InputExample, InputFeatures, DataProcessor
|
||||
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
||||
|
||||
from .metrics import is_sklearn_available
|
||||
if is_sklearn_available():
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .utils import InputExample, InputFeatures, DataProcessor
|
||||
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
||||
|
||||
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
||||
|
||||
93
transformers/data/processors/xnli.py
Normal file
93
transformers/data/processors/xnli.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" XNLI utils (dataset loading and evaluation) """
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .utils import DataProcessor, InputExample
|
||||
from transformers.data.metrics import simple_accuracy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class XnliProcessor(DataProcessor):
|
||||
"""Processor for the XNLI dataset.
|
||||
Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
|
||||
|
||||
def __init__(self, language, train_language = None):
|
||||
self.language = language
|
||||
self.train_language = train_language
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
lg = self.language if self.train_language is None else self.train_language
|
||||
lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % ('train', i)
|
||||
text_a = line[0]
|
||||
text_b = line[1]
|
||||
label = "contradiction" if line[2] == "contradictory" else line[2]
|
||||
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
language = line[0]
|
||||
if language != self.language:
|
||||
continue
|
||||
guid = "%s-%s" % ('test', i)
|
||||
text_a = line[6]
|
||||
text_b = line[7]
|
||||
label = line[1]
|
||||
assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["contradiction", "entailment", "neutral"]
|
||||
|
||||
def xnli_compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(labels)
|
||||
if task_name == "xnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise ValueError(f'{task_name} is not a supported task.')
|
||||
|
||||
xnli_processors = {
|
||||
"xnli": XnliProcessor,
|
||||
}
|
||||
|
||||
xnli_output_modes = {
|
||||
"xnli": "classification",
|
||||
}
|
||||
|
||||
xnli_tasks_num_labels = {
|
||||
"xnli": 3,
|
||||
}
|
||||
Reference in New Issue
Block a user