[Feature] Support is_split_into_words in the TokenClassificationPipeline. (#38818)
* some fixes * some fixes * now the pipeline can take list of tokens as input and is_split_into_words argument * now the pipeline can take list of tokens as input and is_split_into_words argument * now the pipeline can take list of tokens as input and is_split_into_words argument and we can handle batches of tokenized input * now the pipeline can take list of tokens as input and is_split_into_words argument and we can handle batches of tokenized input * solving test problems * some fixes * some fixes * modify tests * aligning start and end correctly * adding tests * some formatting * some formatting * some fixes * some fixes * some fixes * resolve conflicts * removing unimportant lines * removing unimportant lines * generalize to other languages * generalize to other languages * generalize to other languages * generalize to other languages
This commit is contained in:
@@ -308,6 +308,54 @@ class TokenClassificationPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_is_split_into_words(self):
|
||||
"""
|
||||
Tests the pipeline with pre-tokenized inputs (is_split_into_words=True)
|
||||
and validates that the character offsets are correct.
|
||||
"""
|
||||
token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")
|
||||
|
||||
# Input is a list of words
|
||||
words = ["Hello", "Sarah", "lives", "in", "New", "York"]
|
||||
|
||||
# The reconstructed sentence will be "Hello Sarah lives in New York"
|
||||
# - "Sarah": starts at index 6, ends at 11
|
||||
# - "New York": starts at index 21, ends at 29
|
||||
|
||||
output = token_classifier(words, is_split_into_words=True)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(output),
|
||||
[
|
||||
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
|
||||
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
|
||||
],
|
||||
)
|
||||
|
||||
# Also test batching with pre-tokenized inputs
|
||||
words2 = ["My", "name", "is", "Wolfgang", "and", "I", "live", "in", "Berlin"]
|
||||
batch_output = token_classifier([words, words2], is_split_into_words=True)
|
||||
|
||||
# Expected for second sentence ("My name is Wolfgang and I live in Berlin")
|
||||
# - "Wolfgang": starts at 12, ends at 20
|
||||
# - "Berlin": starts at 36, ends at 42
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(batch_output),
|
||||
[
|
||||
[
|
||||
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
|
||||
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
|
||||
],
|
||||
[
|
||||
{"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 12, "end": 20},
|
||||
{"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 36, "end": 42},
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_chunking_fast(self):
|
||||
# Note: We cannot run the test on "conflicts" on the chunking.
|
||||
@@ -953,19 +1001,24 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
string = "This is a simple input"
|
||||
|
||||
inputs, offset_mapping = self.args_parser(string)
|
||||
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(string)
|
||||
self.assertEqual(inputs, [string])
|
||||
self.assertFalse(is_split_into_words)
|
||||
self.assertEqual(offset_mapping, None)
|
||||
|
||||
inputs, offset_mapping = self.args_parser([string, string])
|
||||
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser([string, string])
|
||||
self.assertEqual(inputs, [string, string])
|
||||
self.assertFalse(is_split_into_words)
|
||||
self.assertEqual(offset_mapping, None)
|
||||
|
||||
inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
|
||||
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
|
||||
string, offset_mapping=[(0, 1), (1, 2)]
|
||||
)
|
||||
self.assertEqual(inputs, [string])
|
||||
self.assertFalse(is_split_into_words)
|
||||
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
|
||||
|
||||
inputs, offset_mapping = self.args_parser(
|
||||
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
|
||||
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
|
||||
)
|
||||
self.assertEqual(inputs, [string, string])
|
||||
|
||||
Reference in New Issue
Block a user