Better heuristic for token-classification pipeline. (#12611)
* Better heuristic for token-classification pipeline. Relooking at the problem makes thing actually much simpler, when we look at ids from a tokenizer, we have no way in **general** to recover if some substring is part of a word or not. However, within the pipeline, with offsets we still have access to the original string, so we can simply look if previous character (if it exists) of a token, is actually a space. This will obviously be wrong for tokenizers that contain spaces within tokens, tokenizers where offsets include spaces too (Don't think there are a lot). This heuristic hopefully is fully bc and still can handle non-word based tokenizers. * Updating test with real values. * We still need the older "correct" heuristic to prevent fusing punctuation. * Adding a real warning when important.
This commit is contained in:
@@ -270,7 +270,19 @@ class TokenClassificationPipeline(Pipeline):
|
||||
if offset_mapping is not None:
|
||||
start_ind, end_ind = offset_mapping[idx]
|
||||
word_ref = sentence[start_ind:end_ind]
|
||||
is_subword = len(word_ref) != len(word)
|
||||
if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
|
||||
# This is a BPE, word aware tokenizer, there is a correct way
|
||||
# to fuse tokens
|
||||
is_subword = len(word) != len(word_ref)
|
||||
else:
|
||||
# This is a fallback heuristic. This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately.
|
||||
if self.aggregation_strategy in {
|
||||
AggregationStrategy.FIRST,
|
||||
AggregationStrategy.AVERAGE,
|
||||
AggregationStrategy.MAX,
|
||||
}:
|
||||
warnings.warn(UserWarning, "Tokenizer does not support real words, using fallback heuristic")
|
||||
is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False
|
||||
|
||||
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
|
||||
word = word_ref
|
||||
|
||||
@@ -191,6 +191,19 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_aggregation_strategy_byte_level_tokenizer(self):
|
||||
sentence = "Groenlinks praat over Schiphol."
|
||||
ner = pipeline("ner", model="xlm-roberta-large-finetuned-conll02-dutch", aggregation_strategy="max")
|
||||
self.assertEqual(
|
||||
nested_simplify(ner(sentence)),
|
||||
[
|
||||
{"end": 10, "entity_group": "ORG", "score": 0.994, "start": 0, "word": "Groenlinks"},
|
||||
{"entity_group": "LOC", "score": 1.0, "word": "Schiphol.", "start": 22, "end": 31},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_aggregation_strategy(self):
|
||||
model_name = self.small_models[0]
|
||||
|
||||
Reference in New Issue
Block a user