From a3bd7637322b5928409da695586eefb482c0c9f0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Jul 2021 16:21:26 +0200 Subject: [PATCH] Better heuristic for token-classification pipeline. (#12611) * Better heuristic for token-classification pipeline. Relooking at the problem makes things actually much simpler: when we look at ids from a tokenizer, we have no way in **general** to recover whether some substring is part of a word or not. However, within the pipeline, with offsets we still have access to the original string, so we can simply check whether the character preceding a token (if it exists) is actually a space. This will obviously be wrong for tokenizers that contain spaces within tokens, and for tokenizers where offsets include spaces too (I don't think there are many). This heuristic is hopefully fully backward compatible (bc) and can still handle non-word based tokenizers. * Updating test with real values. * We still need the older "correct" heuristic to prevent fusing punctuation. * Adding a real warning when important. --- src/transformers/pipelines/token_classification.py | 14 +++++++++++++- tests/test_pipelines_token_classification.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index 3d155dcbfe..56025e8641 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -270,7 +270,19 @@ class TokenClassificationPipeline(Pipeline): if offset_mapping is not None: start_ind, end_ind = offset_mapping[idx] word_ref = sentence[start_ind:end_ind] - is_subword = len(word_ref) != len(word) + if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None): + # This is a BPE, word aware tokenizer, there is a correct way + # to fuse tokens + is_subword = len(word) != len(word_ref) + else: + # This is a fallback heuristic.
This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately. + if self.aggregation_strategy in { + AggregationStrategy.FIRST, + AggregationStrategy.AVERAGE, + AggregationStrategy.MAX, + }: + warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning) + is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False if int(input_ids[idx]) == self.tokenizer.unk_token_id: word = word_ref diff --git a/tests/test_pipelines_token_classification.py b/tests/test_pipelines_token_classification.py index ce33656314..9d27f95896 100644 --- a/tests/test_pipelines_token_classification.py +++ b/tests/test_pipelines_token_classification.py @@ -191,6 +191,19 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. ], ) + @require_torch + @slow + def test_aggregation_strategy_byte_level_tokenizer(self): + sentence = "Groenlinks praat over Schiphol." + ner = pipeline("ner", model="xlm-roberta-large-finetuned-conll02-dutch", aggregation_strategy="max") + self.assertEqual( + nested_simplify(ner(sentence)), + [ + {"end": 10, "entity_group": "ORG", "score": 0.994, "start": 0, "word": "Groenlinks"}, + {"entity_group": "LOC", "score": 1.0, "word": "Schiphol.", "start": 22, "end": 31}, + ], + ) + @require_torch def test_aggregation_strategy(self): model_name = self.small_models[0]