From 3e58b6b7b8f1b7df121644972dc00278ece6c707 Mon Sep 17 00:00:00 2001 From: Joe Davison Date: Tue, 27 Oct 2020 14:09:55 -0400 Subject: [PATCH] infer entailment label id on zero shot pipeline (#8059) * add entailment dim argument * rename dim -> id * fix last name change, style * rm arg, auto-infer only * typo * rm superfluous import --- src/transformers/pipelines.py | 25 ++++++++++++++++++++----- tests/test_pipelines_zero_shot.py | 21 +++++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 4a7c42fd86..566fcce096 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1085,8 +1085,8 @@ class ZeroShotClassificationPipeline(Pipeline): Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis pair and passed to the pretrained model. Then, the logit for `entailment` is taken as the logit for the candidate - label being valid. Any NLI model can be used as long as the first output logit corresponds to `contradiction` and - the last to `entailment`. + label being valid. Any NLI model can be used, but the id of the `entailment` label must be included in the model + config's :attr:`~transformers.PretrainedConfig.label2id`. This NLI pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task identifier: :obj:`"zero-shot-classification"`. @@ -1097,6 +1097,18 @@ class ZeroShotClassificationPipeline(Pipeline): def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs): super().__init__(*args, args_parser=args_parser, **kwargs) + if self.entailment_id == -1: + logger.warning( + "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to " + "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs." + ) + + @property + def entailment_id(self): + for label, ind in self.model.config.label2id.items(): + if label.lower().startswith("entail"): + return ind + return -1 def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs): """ @@ -1115,7 +1127,8 @@ class ZeroShotClassificationPipeline(Pipeline): def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False): """ - Classify the sequence(s) given as inputs. + Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline` + documentation for more information. Args: sequences (:obj:`str` or :obj:`List[str]`): @@ -1153,11 +1166,13 @@ class ZeroShotClassificationPipeline(Pipeline): if not multi_class: # softmax the "entailment" logits over all candidate labels - entail_logits = reshaped_outputs[..., -1] + entail_logits = reshaped_outputs[..., self.entailment_id] scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True) else: # softmax over the entailment vs. contradiction dim for each label independently - entail_contr_logits = reshaped_outputs[..., [0, -1]] + entailment_id = self.entailment_id + contradiction_id = -1 if entailment_id == 0 else 0 + entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]] scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True) scores = scores[..., 1] diff --git a/tests/test_pipelines_zero_shot.py b/tests/test_pipelines_zero_shot.py index 42adfc27ce..a2d6c590de 100644 --- a/tests/test_pipelines_zero_shot.py +++ b/tests/test_pipelines_zero_shot.py @@ -1,4 +1,5 @@ import unittest +from copy import deepcopy from transformers.pipelines import Pipeline @@ -18,6 +19,24 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte sum += score self.assertAlmostEqual(sum, 1.0) + def _test_entailment_id(self, nlp: Pipeline): + config = nlp.model.config + original_config = deepcopy(config) + + config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2} + self.assertEqual(nlp.entailment_id, -1) + + config.label2id = {"entailment": 0, "neutral": 1, "contradiction": 2} + self.assertEqual(nlp.entailment_id, 0) + + config.label2id = {"ENTAIL": 0, "NON-ENTAIL": 1} + self.assertEqual(nlp.entailment_id, 0) + + config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0} + self.assertEqual(nlp.entailment_id, 2) + + nlp.model.config = original_config + def _test_pipeline(self, nlp: Pipeline): output_keys = {"sequence", "labels", "scores"} valid_mono_inputs = [ @@ -59,6 +78,8 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte ] self.assertIsNotNone(nlp) + self._test_entailment_id(nlp) + for mono_input in valid_mono_inputs: mono_result = nlp(**mono_input) self.assertIsInstance(mono_result, dict)