From 900daec24ee7f2e90ac63b0d350e7f223a5bcded Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 15 Feb 2021 12:22:45 +0100 Subject: [PATCH] Fixing NER pipeline for list inputs. (#10184) Fixes #10168 --- .../pipelines/token_classification.py | 13 ++-- tests/test_pipelines_ner.py | 67 ++++++++++++++----- 2 files changed, 60 insertions(+), 20 deletions(-) diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index 5dce3402cd..d9431c0cb7 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler): Handles arguments for token classification. """ - def __call__(self, *args, **kwargs): + def __call__(self, inputs: Union[str, List[str]], **kwargs): - if args is not None and len(args) > 0: - inputs = list(args) + if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0: + inputs = list(inputs) batch_size = len(inputs) + elif isinstance(inputs, str): + inputs = [inputs] + batch_size = 1 else: raise ValueError("At least one input is required.") @@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline): Only exists if the offsets are available within the tokenizer """ - inputs, offset_mappings = self._args_parser(inputs, **kwargs) + _inputs, offset_mappings = self._args_parser(inputs, **kwargs) answers = [] - for i, sentence in enumerate(inputs): + for i, sentence in enumerate(_inputs): # Manage correct placement of the tensors with self.device_placement(): diff --git a/tests/test_pipelines_ner.py b/tests/test_pipelines_ner.py index 9e21456a63..c7b8171ef2 100644 --- a/tests/test_pipelines_ner.py +++ b/tests/test_pipelines_ner.py @@ -14,14 +14,17 @@ import unittest -from transformers import AutoTokenizer, pipeline +from transformers import AutoTokenizer, is_torch_available, pipeline from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler from transformers.testing_utils import require_tf, require_torch, slow from .test_pipelines_common import CustomInputPipelineCommonMixin -VALID_INPUTS = ["A simple string", ["list of strings"]] +if is_torch_available(): + import numpy as np + +VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]] class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): @@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): @require_torch def test_simple(self): nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True) - output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York") + sentence = "Hello Sarah Jessica Parker who Jessica lives in New York" + sentence2 = "This is a simple test" + output = nlp(sentence) def simplify(output): - for i in range(len(output)): - output[i]["score"] = round(output[i]["score"], 3) - return output + if isinstance(output, (list, tuple)): + return [simplify(item) for item in output] + elif isinstance(output, dict): + return {simplify(k): simplify(v) for k, v in output.items()} + elif isinstance(output, (str, int, np.int64)): + return output + elif isinstance(output, float): + return round(output, 3) + else: + raise Exception(f"Cannot handle {type(output)}") - output = simplify(output) + output_ = simplify(output) self.assertEqual( - output, + output_, [ { "entity_group": "PER", @@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): ], ) + output = nlp([sentence, sentence2]) + output_ = simplify(output) + + self.assertEqual( + output_, + [ + [ + {"entity_group": "PER", "score": 0.996, "word": "Sarah Jessica Parker", "start": 6, "end": 26}, + {"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38}, + {"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56}, + ], + [], + ], + ) + @require_torch def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self): for model_name in self.small_models: @@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase): self.assertEqual(inputs, [string]) self.assertEqual(offset_mapping, None) - inputs, offset_mapping = self.args_parser(string, string) + inputs, offset_mapping = self.args_parser([string, string]) self.assertEqual(inputs, [string, string]) self.assertEqual(offset_mapping, None) @@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase): self.assertEqual(inputs, [string]) self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]]) - inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]) + inputs, offset_mapping = self.args_parser( + [string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]] + ) self.assertEqual(inputs, [string, string]) self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]]) def test_errors(self): string = "This is a simple input" - # 2 sentences, 1 offset_mapping - with self.assertRaises(ValueError): + # 2 sentences, 1 offset_mapping, args + with self.assertRaises(TypeError): self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]]) - # 2 sentences, 1 offset_mapping - with self.assertRaises(ValueError): + # 2 sentences, 1 offset_mapping, args + with self.assertRaises(TypeError): self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)]) + # 2 sentences, 1 offset_mapping, input_list + with self.assertRaises(ValueError): + self.args_parser([string, string], offset_mapping=[[(0, 1), (1, 2)]]) + + # 2 sentences, 1 offset_mapping, input_list + with self.assertRaises(ValueError): + self.args_parser([string, string], offset_mapping=[(0, 1), (1, 2)]) + # 1 sentences, 2 offset_mapping with self.assertRaises(ValueError): self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]) # 0 sentences, 1 offset_mapping - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])