[Feature] Support is_split_into_words in the TokenClassificationPipeline. (#38818)

* some fixes

* some fixes

* now the pipeline can take list of tokens as input and is_split_into_words argument

* now the pipeline can take list of tokens as input and is_split_into_words argument

* now the pipeline can take list of tokens as input and is_split_into_words argument and we can handle batches of tokenized input

* now the pipeline can take list of tokens as input and is_split_into_words argument and we can handle batches of tokenized input

* solving test problems

* some fixes

* some fixes

* modify tests

* aligning start and end correctly

* adding tests

* some formatting

* some formatting

* some fixes

* some fixes

* some fixes

* resolve conflicts

* removing unimportant lines

* removing unimportant lines

* generalize to other languages

* generalize to other languages

* generalize to other languages

* generalize to other languages
This commit is contained in:
Yusuf Shihata
2025-06-23 18:31:32 +03:00
committed by GitHub
parent 2ce02b98bf
commit 9eac19eb59
2 changed files with 141 additions and 11 deletions

View File

@@ -308,6 +308,54 @@ class TokenClassificationPipelineTests(unittest.TestCase):
],
)
@require_torch
@slow
def test_is_split_into_words(self):
"""
Tests the pipeline with pre-tokenized inputs (is_split_into_words=True)
and validates that the character offsets are correct.
"""
token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")
# Input is a list of words
words = ["Hello", "Sarah", "lives", "in", "New", "York"]
# The reconstructed sentence will be "Hello Sarah lives in New York"
# - "Sarah": starts at index 6, ends at 11
# - "New York": starts at index 21, ends at 29
output = token_classifier(words, is_split_into_words=True)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
],
)
# Also test batching with pre-tokenized inputs
words2 = ["My", "name", "is", "Wolfgang", "and", "I", "live", "in", "Berlin"]
batch_output = token_classifier([words, words2], is_split_into_words=True)
# Expected for second sentence ("My name is Wolfgang and I live in Berlin")
# - "Wolfgang": starts at 12, ends at 20
# - "Berlin": starts at 36, ends at 42
self.assertEqual(
nested_simplify(batch_output),
[
[
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
],
[
{"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 12, "end": 20},
{"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 36, "end": 42},
],
],
)
@require_torch
def test_chunking_fast(self):
# Note: We cannot run the test on "conflicts" on the chunking.
@@ -953,19 +1001,24 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
def test_simple(self):
string = "This is a simple input"
inputs, offset_mapping = self.args_parser(string)
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(string)
self.assertEqual(inputs, [string])
self.assertFalse(is_split_into_words)
self.assertEqual(offset_mapping, None)
inputs, offset_mapping = self.args_parser([string, string])
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser([string, string])
self.assertEqual(inputs, [string, string])
self.assertFalse(is_split_into_words)
self.assertEqual(offset_mapping, None)
inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
string, offset_mapping=[(0, 1), (1, 2)]
)
self.assertEqual(inputs, [string])
self.assertFalse(is_split_into_words)
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
inputs, offset_mapping = self.args_parser(
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
)
self.assertEqual(inputs, [string, string])