From d29baf69bbaa712af25c8344020e12e455e80727 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Nov 2021 14:47:52 +0100 Subject: [PATCH] Fixing mishandling of `ignore_labels`. (#14274) Fixes #14272 --- src/transformers/pipelines/token_classification.py | 9 +++++---- tests/test_pipelines_token_classification.py | 9 +++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index 49ede61dc8..6324327dcd 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -96,7 +96,6 @@ class TokenClassificationPipeline(Pipeline): default_input_names = "sequences" def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs): - self.ignore_labels = ["O"] super().__init__(*args, **kwargs) self.check_model_type( TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING @@ -216,7 +215,9 @@ class TokenClassificationPipeline(Pipeline): **model_inputs, } - def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE): + def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None): + if ignore_labels is None: + ignore_labels = ["O"] logits = model_outputs["logits"][0].numpy() sentence = model_outputs["sentence"] input_ids = model_outputs["input_ids"][0] @@ -235,8 +236,8 @@ class TokenClassificationPipeline(Pipeline): entities = [ entity for entity in grouped_entities - if entity.get("entity", None) not in self.ignore_labels - and entity.get("entity_group", None) not in self.ignore_labels + if entity.get("entity", None) not in ignore_labels + and entity.get("entity_group", None) not in ignore_labels ] return entities diff --git a/tests/test_pipelines_token_classification.py b/tests/test_pipelines_token_classification.py index caeef47d95..cdb0709119 100644 --- a/tests/test_pipelines_token_classification.py +++ b/tests/test_pipelines_token_classification.py @@ -627,6 +627,15 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ], ) + token_classifier = pipeline( + task="token-classification", model=model_name, framework="pt", ignore_labels=["O", "I-MISC"] + ) + outputs = token_classifier("This is a test !") + self.assertEqual( + nested_simplify(outputs), + [], + ) + @require_torch def test_pt_ignore_subwords_slow_tokenizer_raises(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"