correct label extraction + add note on discrepancies on trained MNLI model and HANS (#6221)
This commit is contained in:
@@ -255,7 +255,11 @@ class HansProcessor(DataProcessor):
|
|||||||
return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_evaluation_set.txt")), "dev")
|
return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_evaluation_set.txt")), "dev")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
"""See base class."""
|
"""See base class.
|
||||||
|
Note that we follow the standard three labels for MNLI
|
||||||
|
(see :class:`~transformers.data.processors.utils.MnliProcessor`)
|
||||||
|
but the HANS evaluation groups `contradiction` and `neutral` into `non-entailment` (label 0) while
|
||||||
|
`entailment` is label 1."""
|
||||||
return ["contradiction", "entailment", "neutral"]
|
return ["contradiction", "entailment", "neutral"]
|
||||||
|
|
||||||
def _create_examples(self, lines, set_type):
|
def _create_examples(self, lines, set_type):
|
||||||
@@ -268,7 +272,7 @@ class HansProcessor(DataProcessor):
|
|||||||
text_a = line[5]
|
text_a = line[5]
|
||||||
text_b = line[6]
|
text_b = line[6]
|
||||||
pairID = line[7][2:] if line[7].startswith("ex") else line[7]
|
pairID = line[7][2:] if line[7].startswith("ex") else line[7]
|
||||||
label = line[-1]
|
label = line[0]
|
||||||
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, pairID=pairID))
|
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, pairID=pairID))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user