infer entailment label id on zero shot pipeline (#8059)
* add entailment dim argument
* rename dim -> id
* fix last name change, style
* rm arg, auto-infer only
* typo
* rm superfluous import
This commit is contained in:
@@ -1085,8 +1085,8 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
|
Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
|
||||||
pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the candidate
|
pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the candidate
|
||||||
label being valid. Any NLI model can be used as long as the first output logit corresponds to `contradiction` and
|
label being valid. Any NLI model can be used, but the id of the `entailment` label must be included in the model
|
||||||
the last to `entailment`.
|
config's :attr:`~transformers.PretrainedConfig.label2id`.
|
||||||
|
|
||||||
This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task identifier:
|
This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task identifier:
|
||||||
:obj:`"zero-shot-classification"`.
|
:obj:`"zero-shot-classification"`.
|
||||||
@@ -1097,6 +1097,18 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
    """
    Build the pipeline and check that the ``entailment`` label id can be inferred.

    All positional and keyword arguments are forwarded to the parent constructor together with
    :obj:`args_parser`. When :obj:`self.entailment_id` evaluates to ``-1`` a warning is logged.
    """
    super().__init__(*args, args_parser=args_parser, **kwargs)

    entailment_found = self.entailment_id != -1
    if entailment_found:
        return

    # Nothing in label2id looked like an entailment label: warn instead of failing.
    logger.warning(
        "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
        "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs."
    )
|
||||||
|
|
||||||
|
@property
def entailment_id(self):
    """
    Id of the ``entailment`` label, inferred from the ``label2id`` mapping of the model config.

    Returns:
        The id mapped to the first label whose lower-cased name starts with ``"entail"``, or ``-1`` when no such
        label exists or the config has no ``label2id`` mapping.
    """
    # A config may carry no label2id mapping at all; treat that as "not found" instead of raising.
    label2id = getattr(self.model.config, "label2id", None) or {}
    for label, ind in label2id.items():
        # str() keeps the lookup from crashing on non-string label keys.
        if str(label).lower().startswith("entail"):
            return ind
    return -1
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -1115,7 +1127,8 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
|
def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
|
||||||
"""
|
"""
|
||||||
Classify the sequence(s) given as inputs.
|
Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
|
||||||
|
documentation for more information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequences (:obj:`str` or :obj:`List[str]`):
|
sequences (:obj:`str` or :obj:`List[str]`):
|
||||||
@@ -1153,11 +1166,13 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
if not multi_class:
|
if not multi_class:
|
||||||
# softmax the "entailment" logits over all candidate labels
|
# softmax the "entailment" logits over all candidate labels
|
||||||
entail_logits = reshaped_outputs[..., -1]
|
entail_logits = reshaped_outputs[..., self.entailment_id]
|
||||||
scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
|
scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
|
||||||
else:
|
else:
|
||||||
# softmax over the entailment vs. contradiction dim for each label independently
|
# softmax over the entailment vs. contradiction dim for each label independently
|
||||||
entail_contr_logits = reshaped_outputs[..., [0, -1]]
|
entailment_id = self.entailment_id
|
||||||
|
contradiction_id = -1 if entailment_id == 0 else 0
|
||||||
|
entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
|
||||||
scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
|
scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
|
||||||
scores = scores[..., 1]
|
scores = scores[..., 1]
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from transformers.pipelines import Pipeline
|
from transformers.pipelines import Pipeline
|
||||||
|
|
||||||
@@ -18,6 +19,24 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
|||||||
sum += score
|
sum += score
|
||||||
self.assertAlmostEqual(sum, 1.0)
|
self.assertAlmostEqual(sum, 1.0)
|
||||||
|
|
||||||
|
def _test_entailment_id(self, nlp: Pipeline):
|
||||||
|
config = nlp.model.config
|
||||||
|
original_config = deepcopy(config)
|
||||||
|
|
||||||
|
config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}
|
||||||
|
self.assertEqual(nlp.entailment_id, -1)
|
||||||
|
|
||||||
|
config.label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
|
||||||
|
self.assertEqual(nlp.entailment_id, 0)
|
||||||
|
|
||||||
|
config.label2id = {"ENTAIL": 0, "NON-ENTAIL": 1}
|
||||||
|
self.assertEqual(nlp.entailment_id, 0)
|
||||||
|
|
||||||
|
config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0}
|
||||||
|
self.assertEqual(nlp.entailment_id, 2)
|
||||||
|
|
||||||
|
nlp.model.config = original_config
|
||||||
|
|
||||||
def _test_pipeline(self, nlp: Pipeline):
|
def _test_pipeline(self, nlp: Pipeline):
|
||||||
output_keys = {"sequence", "labels", "scores"}
|
output_keys = {"sequence", "labels", "scores"}
|
||||||
valid_mono_inputs = [
|
valid_mono_inputs = [
|
||||||
@@ -59,6 +78,8 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
|||||||
]
|
]
|
||||||
self.assertIsNotNone(nlp)
|
self.assertIsNotNone(nlp)
|
||||||
|
|
||||||
|
self._test_entailment_id(nlp)
|
||||||
|
|
||||||
for mono_input in valid_mono_inputs:
|
for mono_input in valid_mono_inputs:
|
||||||
mono_result = nlp(**mono_input)
|
mono_result = nlp(**mono_input)
|
||||||
self.assertIsInstance(mono_result, dict)
|
self.assertIsInstance(mono_result, dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user