# Patch commentary (review notes only). Everything above the first
# "diff --git" header is free text and belongs to no hunk.
#
# src/transformers/pipelines/token_classification.py
#   In the fallback branch for entity names that are not in "B-"/"I-"
#   format, `bi` is now set to "I" instead of "B". `tag` is still the
#   unmodified entity name, and the method still returns `(bi, tag)`.
#
# tests/test_pipelines_token_classification.py
#   Adds test_aggregation_strategy_no_b_i_prefix. The test overrides
#   `id2label` with prefix-less labels (O, MISC, PER, ORG, LOC) and calls
#   `token_classifier.aggregate()` on three hand-built tokens:
#     * AggregationStrategy.NONE is expected to return the three tokens
#       individually ("En"/LOC, "##zo"/LOC, "UN"/ORG).
#     * AggregationStrategy.SIMPLE is expected to return "En" + "##zo" as a
#       single LOC group "Enzo" (start 0, end 4) and "UN" as an ORG group.
#
# NOTE(review): the first two example dicts use "# fmt : off" (with a space
# before the colon) and have no closing "# fmt: on", while the third dict
# uses "# fmt: off" / "# fmt: on". Presumably a typo -- verify whether the
# formatter recognises the spaced spelling before relying on it.
#
# NOTE(review): the original text had its line breaks and indentation
# collapsed. The layout below was reconstructed from the hunk line counts
# and from Python block structure; the blank context lines at the end of
# the first hunk and after ")" in the second hunk are inferred from those
# counts. Confirm whitespace against the target files before applying.
diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py
index 6fc1de1dcb..4a7bbeb77f 100644
--- a/src/transformers/pipelines/token_classification.py
+++ b/src/transformers/pipelines/token_classification.py
@@ -411,7 +411,8 @@ class TokenClassificationPipeline(Pipeline):
             tag = entity_name[2:]
         else:
             # It's not in B-, I- format
-            bi = "B"
+            # Default to I- for continuation.
+            bi = "I"
             tag = entity_name
         return bi, tag
 
diff --git a/tests/test_pipelines_token_classification.py b/tests/test_pipelines_token_classification.py
index b8b572e517..dcb4a2e535 100644
--- a/tests/test_pipelines_token_classification.py
+++ b/tests/test_pipelines_token_classification.py
@@ -318,6 +318,59 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
             ],
         )
 
+    @require_torch
+    def test_aggregation_strategy_no_b_i_prefix(self):
+        model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
+        # Just to understand scores indexes in this test
+        token_classifier.model.config.id2label = {0: "O", 1: "MISC", 2: "PER", 3: "ORG", 4: "LOC"}
+        example = [
+            {
+                # fmt : off
+                "scores": np.array([0, 0, 0, 0, 0.9968166351318359]),
+                "index": 1,
+                "is_subword": False,
+                "word": "En",
+                "start": 0,
+                "end": 2,
+            },
+            {
+                # fmt : off
+                "scores": np.array([0, 0, 0, 0, 0.9957635998725891]),
+                "index": 2,
+                "is_subword": True,
+                "word": "##zo",
+                "start": 2,
+                "end": 4,
+            },
+            {
+                # fmt: off
+                "scores": np.array([0, 0, 0, 0.9986497163772583, 0]),
+                # fmt: on
+                "index": 7,
+                "word": "UN",
+                "is_subword": False,
+                "start": 11,
+                "end": 13,
+            },
+        ]
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
+            [
+                {"end": 2, "entity": "LOC", "score": 0.997, "start": 0, "word": "En", "index": 1},
+                {"end": 4, "entity": "LOC", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
+                {"end": 13, "entity": "ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
+            ],
+        )
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
+            [
+                {"entity_group": "LOC", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
+            ],
+        )
+
     @require_torch
     def test_aggregation_strategy(self):
         model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"