@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
|
|||||||
Handles arguments for token classification.
|
Handles arguments for token classification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, inputs: Union[str, List[str]], **kwargs):
|
||||||
|
|
||||||
if args is not None and len(args) > 0:
|
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
|
||||||
inputs = list(args)
|
inputs = list(inputs)
|
||||||
batch_size = len(inputs)
|
batch_size = len(inputs)
|
||||||
|
elif isinstance(inputs, str):
|
||||||
|
inputs = [inputs]
|
||||||
|
batch_size = 1
|
||||||
else:
|
else:
|
||||||
raise ValueError("At least one input is required.")
|
raise ValueError("At least one input is required.")
|
||||||
|
|
||||||
@@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
Only exists if the offsets are available within the tokenizer
|
Only exists if the offsets are available within the tokenizer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
_inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
||||||
|
|
||||||
answers = []
|
answers = []
|
||||||
|
|
||||||
for i, sentence in enumerate(inputs):
|
for i, sentence in enumerate(_inputs):
|
||||||
|
|
||||||
# Manage correct placement of the tensors
|
# Manage correct placement of the tensors
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
|
|||||||
@@ -14,14 +14,17 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoTokenizer, pipeline
|
from transformers import AutoTokenizer, is_torch_available, pipeline
|
||||||
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
|
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
|
||||||
from transformers.testing_utils import require_tf, require_torch, slow
|
from transformers.testing_utils import require_tf, require_torch, slow
|
||||||
|
|
||||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||||
|
|
||||||
|
|
||||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
if is_torch_available():
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
|
||||||
|
|
||||||
|
|
||||||
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||||
@@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
|
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
|
||||||
output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York")
|
sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
|
||||||
|
sentence2 = "This is a simple test"
|
||||||
|
output = nlp(sentence)
|
||||||
|
|
||||||
def simplify(output):
|
def simplify(output):
|
||||||
for i in range(len(output)):
|
if isinstance(output, (list, tuple)):
|
||||||
output[i]["score"] = round(output[i]["score"], 3)
|
return [simplify(item) for item in output]
|
||||||
|
elif isinstance(output, dict):
|
||||||
|
return {simplify(k): simplify(v) for k, v in output.items()}
|
||||||
|
elif isinstance(output, (str, int, np.int64)):
|
||||||
return output
|
return output
|
||||||
|
elif isinstance(output, float):
|
||||||
|
return round(output, 3)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Cannot handle {type(output)}")
|
||||||
|
|
||||||
output = simplify(output)
|
output_ = simplify(output)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
output,
|
output_,
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"entity_group": "PER",
|
"entity_group": "PER",
|
||||||
@@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
output = nlp([sentence, sentence2])
|
||||||
|
output_ = simplify(output)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
output_,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"entity_group": "PER", "score": 0.996, "word": "Sarah Jessica Parker", "start": 6, "end": 26},
|
||||||
|
{"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
|
||||||
|
{"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
|
||||||
|
],
|
||||||
|
[],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
|
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
|
||||||
for model_name in self.small_models:
|
for model_name in self.small_models:
|
||||||
@@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(inputs, [string])
|
self.assertEqual(inputs, [string])
|
||||||
self.assertEqual(offset_mapping, None)
|
self.assertEqual(offset_mapping, None)
|
||||||
|
|
||||||
inputs, offset_mapping = self.args_parser(string, string)
|
inputs, offset_mapping = self.args_parser([string, string])
|
||||||
self.assertEqual(inputs, [string, string])
|
self.assertEqual(inputs, [string, string])
|
||||||
self.assertEqual(offset_mapping, None)
|
self.assertEqual(offset_mapping, None)
|
||||||
|
|
||||||
@@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(inputs, [string])
|
self.assertEqual(inputs, [string])
|
||||||
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
|
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
|
||||||
|
|
||||||
inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
inputs, offset_mapping = self.args_parser(
|
||||||
|
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
|
||||||
|
)
|
||||||
self.assertEqual(inputs, [string, string])
|
self.assertEqual(inputs, [string, string])
|
||||||
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||||
|
|
||||||
def test_errors(self):
|
def test_errors(self):
|
||||||
string = "This is a simple input"
|
string = "This is a simple input"
|
||||||
|
|
||||||
# 2 sentences, 1 offset_mapping
|
# 2 sentences, 1 offset_mapping, args
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(TypeError):
|
||||||
self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])
|
self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])
|
||||||
|
|
||||||
# 2 sentences, 1 offset_mapping
|
# 2 sentences, 1 offset_mapping, args
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(TypeError):
|
||||||
self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])
|
self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])
|
||||||
|
|
||||||
|
# 2 sentences, 1 offset_mapping, input_list
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.args_parser([string, string], offset_mapping=[[(0, 1), (1, 2)]])
|
||||||
|
|
||||||
|
# 2 sentences, 1 offset_mapping, input_list
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.args_parser([string, string], offset_mapping=[(0, 1), (1, 2)])
|
||||||
|
|
||||||
# 1 sentences, 2 offset_mapping
|
# 1 sentences, 2 offset_mapping
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||||
|
|
||||||
# 0 sentences, 1 offset_mapping
|
# 0 sentences, 1 offset_mapping
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(TypeError):
|
||||||
self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
|
self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
|
||||||
|
|||||||
Reference in New Issue
Block a user