infer entailment label id on zero shot pipeline (#8059)
* add entailment dim argument * rename dim -> id * fix last name change, style * rm arg, auto-infer only * typo * rm superfluous import
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from transformers.pipelines import Pipeline
|
||||
|
||||
@@ -18,6 +19,24 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
||||
sum += score
|
||||
self.assertAlmostEqual(sum, 1.0)
|
||||
|
||||
def _test_entailment_id(self, nlp: Pipeline):
|
||||
config = nlp.model.config
|
||||
original_config = deepcopy(config)
|
||||
|
||||
config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}
|
||||
self.assertEqual(nlp.entailment_id, -1)
|
||||
|
||||
config.label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
|
||||
self.assertEqual(nlp.entailment_id, 0)
|
||||
|
||||
config.label2id = {"ENTAIL": 0, "NON-ENTAIL": 1}
|
||||
self.assertEqual(nlp.entailment_id, 0)
|
||||
|
||||
config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0}
|
||||
self.assertEqual(nlp.entailment_id, 2)
|
||||
|
||||
nlp.model.config = original_config
|
||||
|
||||
def _test_pipeline(self, nlp: Pipeline):
|
||||
output_keys = {"sequence", "labels", "scores"}
|
||||
valid_mono_inputs = [
|
||||
@@ -59,6 +78,8 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
||||
]
|
||||
self.assertIsNotNone(nlp)
|
||||
|
||||
self._test_entailment_id(nlp)
|
||||
|
||||
for mono_input in valid_mono_inputs:
|
||||
mono_result = nlp(**mono_input)
|
||||
self.assertIsInstance(mono_result, dict)
|
||||
|
||||
Reference in New Issue
Block a user