Patch token classification pipeline (#8364)
* Patch token classification pipeline * Some added tests for TokenClassificationArgumentHandler (#8366) Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -1333,18 +1333,17 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
|
|||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
|
||||||
if args is not None and len(args) > 0:
|
if args is not None and len(args) > 0:
|
||||||
if isinstance(args, str):
|
inputs = list(args)
|
||||||
inputs = [args]
|
|
||||||
else:
|
|
||||||
inputs = args
|
|
||||||
batch_size = len(inputs)
|
batch_size = len(inputs)
|
||||||
|
else:
|
||||||
|
raise ValueError("At least one input is required.")
|
||||||
|
|
||||||
offset_mapping = kwargs.get("offset_mapping", None)
|
offset_mapping = kwargs.get("offset_mapping")
|
||||||
if offset_mapping:
|
if offset_mapping:
|
||||||
if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
|
if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
|
||||||
offset_mapping = [offset_mapping]
|
offset_mapping = [offset_mapping]
|
||||||
if len(offset_mapping) != batch_size:
|
if len(offset_mapping) != batch_size:
|
||||||
raise ("offset_mapping should have the same batch size as the input")
|
raise ValueError("offset_mapping should have the same batch size as the input")
|
||||||
return inputs, offset_mapping
|
return inputs, offset_mapping
|
||||||
|
|
||||||
|
|
||||||
@@ -1379,20 +1378,19 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
modelcard: Optional[ModelCard] = None,
|
modelcard: Optional[ModelCard] = None,
|
||||||
framework: Optional[str] = None,
|
framework: Optional[str] = None,
|
||||||
args_parser: ArgumentHandler = None,
|
args_parser: ArgumentHandler = TokenClassificationArgumentHandler(),
|
||||||
device: int = -1,
|
device: int = -1,
|
||||||
binary_output: bool = False,
|
binary_output: bool = False,
|
||||||
ignore_labels=["O"],
|
ignore_labels=["O"],
|
||||||
task: str = "",
|
task: str = "",
|
||||||
grouped_entities: bool = False,
|
grouped_entities: bool = False,
|
||||||
ignore_subwords: bool = True,
|
ignore_subwords: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
modelcard=modelcard,
|
modelcard=modelcard,
|
||||||
framework=framework,
|
framework=framework,
|
||||||
args_parser=TokenClassificationArgumentHandler(),
|
|
||||||
device=device,
|
device=device,
|
||||||
binary_output=binary_output,
|
binary_output=binary_output,
|
||||||
task=task,
|
task=task,
|
||||||
@@ -1405,10 +1403,17 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
||||||
|
self._args_parser = args_parser
|
||||||
self.ignore_labels = ignore_labels
|
self.ignore_labels = ignore_labels
|
||||||
self.grouped_entities = grouped_entities
|
self.grouped_entities = grouped_entities
|
||||||
self.ignore_subwords = ignore_subwords
|
self.ignore_subwords = ignore_subwords
|
||||||
|
|
||||||
|
if self.ignore_subwords and not self.tokenizer.is_fast:
|
||||||
|
raise ValueError(
|
||||||
|
"Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option"
|
||||||
|
"to `False` or use a fast tokenizer."
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, inputs: Union[str, List[str]], **kwargs):
|
def __call__(self, inputs: Union[str, List[str]], **kwargs):
|
||||||
"""
|
"""
|
||||||
Classify each token of the text(s) given as inputs.
|
Classify each token of the text(s) given as inputs.
|
||||||
@@ -1429,10 +1434,7 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
corresponding token in the sentence.
|
corresponding token in the sentence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(inputs, str):
|
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
||||||
inputs = [inputs]
|
|
||||||
|
|
||||||
offset_mappings = kwargs.get("offset_mappings")
|
|
||||||
|
|
||||||
answers = []
|
answers = []
|
||||||
|
|
||||||
@@ -1450,14 +1452,13 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
return_offsets_mapping=self.tokenizer.is_fast,
|
return_offsets_mapping=self.tokenizer.is_fast,
|
||||||
)
|
)
|
||||||
if self.tokenizer.is_fast:
|
if self.tokenizer.is_fast:
|
||||||
offset_mapping = tokens["offset_mapping"].cpu().numpy()[0]
|
offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
|
||||||
del tokens["offset_mapping"]
|
|
||||||
elif offset_mappings:
|
elif offset_mappings:
|
||||||
offset_mapping = offset_mappings[i]
|
offset_mapping = offset_mappings[i]
|
||||||
else:
|
else:
|
||||||
raise Exception("To decode [UNK] tokens use a fast tokenizer or provide offset_mapping parameter")
|
offset_mapping = None
|
||||||
special_tokens_mask = tokens["special_tokens_mask"].cpu().numpy()[0]
|
|
||||||
del tokens["special_tokens_mask"]
|
special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
if self.framework == "tf":
|
if self.framework == "tf":
|
||||||
@@ -1482,14 +1483,17 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
]
|
]
|
||||||
|
|
||||||
for idx, label_idx in filtered_labels_idx:
|
for idx, label_idx in filtered_labels_idx:
|
||||||
start_ind, end_ind = offset_mapping[idx]
|
if offset_mapping is not None:
|
||||||
word_ref = sentence[start_ind:end_ind]
|
start_ind, end_ind = offset_mapping[idx]
|
||||||
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
|
word_ref = sentence[start_ind:end_ind]
|
||||||
is_subword = len(word_ref) != len(word)
|
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
|
||||||
|
is_subword = len(word_ref) != len(word)
|
||||||
|
|
||||||
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
|
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
|
||||||
word = word_ref
|
word = word_ref
|
||||||
is_subword = False
|
is_subword = False
|
||||||
|
else:
|
||||||
|
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
|
||||||
|
|
||||||
entity = {
|
entity = {
|
||||||
"word": word,
|
"word": word,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoTokenizer, pipeline
|
from transformers import AutoTokenizer, pipeline
|
||||||
from transformers.pipelines import Pipeline
|
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
|
||||||
from transformers.testing_utils import require_tf, require_torch
|
from transformers.testing_utils import require_tf, require_torch
|
||||||
|
|
||||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||||
@@ -107,13 +107,9 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
def test_tf_only(self):
|
def test_tf_only(self):
|
||||||
model_name = "Narsil/small" # This model only has a TensorFlow version
|
model_name = "Narsil/small" # This model only has a TensorFlow version
|
||||||
# We test that if we don't specificy framework='tf', it gets detected automatically
|
# We test that if we don't specificy framework='tf', it gets detected automatically
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
nlp = pipeline(task="ner", model=model_name)
|
||||||
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
|
|
||||||
self._test_pipeline(nlp)
|
self._test_pipeline(nlp)
|
||||||
|
|
||||||
# offset=tokenizer(VALID_INPUTS[0],return_offsets_mapping=True)['offset_mapping']
|
|
||||||
# pipeline_running_kwargs = {"offset_mapping"} # Additional kwargs to run the pipeline with
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_tf_defaults(self):
|
def test_tf_defaults(self):
|
||||||
for model_name in self.small_models:
|
for model_name in self.small_models:
|
||||||
@@ -122,9 +118,8 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
self._test_pipeline(nlp)
|
self._test_pipeline(nlp)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_tf_small(self):
|
def test_tf_small_ignore_subwords_available_for_fast_tokenizers(self):
|
||||||
for model_name in self.small_models:
|
for model_name in self.small_models:
|
||||||
print(model_name)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
nlp = pipeline(
|
nlp = pipeline(
|
||||||
task="ner",
|
task="ner",
|
||||||
@@ -136,27 +131,41 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self._test_pipeline(nlp)
|
self._test_pipeline(nlp)
|
||||||
|
|
||||||
for model_name in self.small_models:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
||||||
nlp = pipeline(
|
|
||||||
task="ner",
|
|
||||||
model=model_name,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
framework="tf",
|
|
||||||
grouped_entities=True,
|
|
||||||
ignore_subwords=False,
|
|
||||||
)
|
|
||||||
self._test_pipeline(nlp)
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_pt_defaults(self):
|
|
||||||
for model_name in self.small_models:
|
for model_name in self.small_models:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
nlp = pipeline(
|
||||||
|
task="ner",
|
||||||
|
model=model_name,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
framework="tf",
|
||||||
|
grouped_entities=True,
|
||||||
|
ignore_subwords=False,
|
||||||
|
)
|
||||||
|
self._test_pipeline(nlp)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
||||||
|
for model_name in self.small_models:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_pt_defaults_slow_tokenizer(self):
|
||||||
|
for model_name in self.small_models:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
|
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
|
||||||
self._test_pipeline(nlp)
|
self._test_pipeline(nlp)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_torch_small(self):
|
def test_pt_defaults(self):
|
||||||
|
for model_name in self.small_models:
|
||||||
|
nlp = pipeline(task="ner", model=model_name)
|
||||||
|
self._test_pipeline(nlp)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
|
||||||
for model_name in self.small_models:
|
for model_name in self.small_models:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
nlp = pipeline(
|
nlp = pipeline(
|
||||||
@@ -170,3 +179,46 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
|
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
|
||||||
)
|
)
|
||||||
self._test_pipeline(nlp)
|
self._test_pipeline(nlp)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.args_parser = TokenClassificationArgumentHandler()
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
string = "This is a simple input"
|
||||||
|
|
||||||
|
inputs, offset_mapping = self.args_parser(string)
|
||||||
|
self.assertEqual(inputs, [string])
|
||||||
|
self.assertEqual(offset_mapping, None)
|
||||||
|
|
||||||
|
inputs, offset_mapping = self.args_parser(string, string)
|
||||||
|
self.assertEqual(inputs, [string, string])
|
||||||
|
self.assertEqual(offset_mapping, None)
|
||||||
|
|
||||||
|
inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
|
||||||
|
self.assertEqual(inputs, [string])
|
||||||
|
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
|
||||||
|
|
||||||
|
inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||||
|
self.assertEqual(inputs, [string, string])
|
||||||
|
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||||
|
|
||||||
|
def test_errors(self):
|
||||||
|
string = "This is a simple input"
|
||||||
|
|
||||||
|
# 2 sentences, 1 offset_mapping
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])
|
||||||
|
|
||||||
|
# 2 sentences, 1 offset_mapping
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])
|
||||||
|
|
||||||
|
# 1 sentences, 2 offset_mapping
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
|
||||||
|
|
||||||
|
# 0 sentences, 1 offset_mapping
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
|
||||||
|
|||||||
Reference in New Issue
Block a user