Better heuristic for token-classification pipeline. (#12611)
* Better heuristic for token-classification pipeline. Relooking at the problem makes thing actually much simpler, when we look at ids from a tokenizer, we have no way in **general** to recover if some substring is part of a word or not. However, within the pipeline, with offsets we still have access to the original string, so we can simply look if previous character (if it exists) of a token, is actually a space. This will obviously be wrong for tokenizers that contain spaces within tokens, tokenizers where offsets include spaces too (Don't think there are a lot). This heuristic hopefully is fully bc and still can handle non-word based tokenizers. * Updating test with real values. * We still need the older "correct" heuristic to prevent fusing punctuation. * Adding a real warning when important.
This commit is contained in:
@@ -270,7 +270,19 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
if offset_mapping is not None:
|
if offset_mapping is not None:
|
||||||
start_ind, end_ind = offset_mapping[idx]
|
start_ind, end_ind = offset_mapping[idx]
|
||||||
word_ref = sentence[start_ind:end_ind]
|
word_ref = sentence[start_ind:end_ind]
|
||||||
is_subword = len(word_ref) != len(word)
|
if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
|
||||||
|
# This is a BPE, word aware tokenizer, there is a correct way
|
||||||
|
# to fuse tokens
|
||||||
|
is_subword = len(word) != len(word_ref)
|
||||||
|
else:
|
||||||
|
# This is a fallback heuristic. This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately.
|
||||||
|
if self.aggregation_strategy in {
|
||||||
|
AggregationStrategy.FIRST,
|
||||||
|
AggregationStrategy.AVERAGE,
|
||||||
|
AggregationStrategy.MAX,
|
||||||
|
}:
|
||||||
|
warnings.warn(UserWarning, "Tokenizer does not support real words, using fallback heuristic")
|
||||||
|
is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False
|
||||||
|
|
||||||
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
|
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
|
||||||
word = word_ref
|
word = word_ref
|
||||||
|
|||||||
@@ -191,6 +191,19 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_aggregation_strategy_byte_level_tokenizer(self):
|
||||||
|
sentence = "Groenlinks praat over Schiphol."
|
||||||
|
ner = pipeline("ner", model="xlm-roberta-large-finetuned-conll02-dutch", aggregation_strategy="max")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(ner(sentence)),
|
||||||
|
[
|
||||||
|
{"end": 10, "entity_group": "ORG", "score": 0.994, "start": 0, "word": "Groenlinks"},
|
||||||
|
{"entity_group": "LOC", "score": 1.0, "word": "Schiphol.", "start": 22, "end": 31},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_aggregation_strategy(self):
|
def test_aggregation_strategy(self):
|
||||||
model_name = self.small_models[0]
|
model_name = self.small_models[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user