[TokenClassification] Label realignment for subword aggregation (#11680)
* [TokenClassification] Label realignment for subword aggregation Tentative to replace https://github.com/huggingface/transformers/pull/11622/files - Added `AggregationStrategy` - `ignore_subwords` and `grouped_entities` arguments are now fused into `aggregation_strategy`. It makes more sense anyway because `ignore_subwords=True` with `grouped_entities=False` did not have a meaning anyway. - Added 2 new ways to aggregate which are MAX, and AVERAGE - AVERAGE requires a bit more information than the others, for now this case is slightly specific, we should keep that in mind for future changes. - Testing has been modified to reflect new argument, and to check the correct deprecation and the new aggregation_strategy. - Put the testing argument and testing results for aggregation_strategy, close together, so that readers can understand what is supposed to happen. - `aggregate` is now only tested on a small model as it does not mean anything to test it globally for all models. - Previous tests are unchanged in desired output. - Added a new test case that showcases better the difference between the FIRST, MAX and AVERAGE strategies. * Wrong framework. * Addressing three issues. 1- Tags might not follow B-, I- convention, so any tag should work now (assumed as B-TAG) 2- Fixed an issue with average that leads to a substantial code change. 3- The testing suite was not checking for the "index" key for "none" strategy. This is now fixed. The issue is that "O" could not be chosen by AVERAGE strategy because those tokens were filtered out beforehand, so their relative scores were not counted in the average. Now filtering on ignore_labels will happen at the very end of the pipeline fixing that issue. It's a bit hard to make sure this stays like that because we do not have a end-to-end test for that behavior * Formatting. * Adding formatting to code + cleaner handling of B-, I- tags. Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com> Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com> * Typo. Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com> Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com>
This commit is contained in:
@@ -14,16 +14,15 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, is_torch_available, pipeline
|
||||
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
|
||||
from transformers.pipelines import AggregationStrategy, Pipeline, TokenClassificationArgumentHandler
|
||||
from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
|
||||
|
||||
|
||||
@@ -35,210 +34,10 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
large_models = [] # Models tested with the @slow decorator
|
||||
|
||||
def _test_pipeline(self, nlp: Pipeline):
|
||||
output_keys = {"entity", "word", "score", "start", "end"}
|
||||
if nlp.grouped_entities:
|
||||
output_keys = {"entity", "word", "score", "start", "end", "index"}
|
||||
if nlp.aggregation_strategy != AggregationStrategy.NONE:
|
||||
output_keys = {"entity_group", "word", "score", "start", "end"}
|
||||
|
||||
ungrouped_ner_inputs = [
|
||||
[
|
||||
{
|
||||
"entity": "B-PER",
|
||||
"index": 1,
|
||||
"score": 0.9994944930076599,
|
||||
"is_subword": False,
|
||||
"word": "Cons",
|
||||
"start": 0,
|
||||
"end": 4,
|
||||
},
|
||||
{
|
||||
"entity": "B-PER",
|
||||
"index": 2,
|
||||
"score": 0.8025449514389038,
|
||||
"is_subword": True,
|
||||
"word": "##uelo",
|
||||
"start": 4,
|
||||
"end": 8,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 3,
|
||||
"score": 0.9993102550506592,
|
||||
"is_subword": False,
|
||||
"word": "Ara",
|
||||
"start": 9,
|
||||
"end": 11,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 4,
|
||||
"score": 0.9993743896484375,
|
||||
"is_subword": True,
|
||||
"word": "##új",
|
||||
"start": 11,
|
||||
"end": 13,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 5,
|
||||
"score": 0.9992871880531311,
|
||||
"is_subword": True,
|
||||
"word": "##o",
|
||||
"start": 13,
|
||||
"end": 14,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 6,
|
||||
"score": 0.9993029236793518,
|
||||
"is_subword": False,
|
||||
"word": "No",
|
||||
"start": 15,
|
||||
"end": 17,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 7,
|
||||
"score": 0.9981776475906372,
|
||||
"is_subword": True,
|
||||
"word": "##guera",
|
||||
"start": 17,
|
||||
"end": 22,
|
||||
},
|
||||
{
|
||||
"entity": "B-PER",
|
||||
"index": 15,
|
||||
"score": 0.9998136162757874,
|
||||
"is_subword": False,
|
||||
"word": "Andrés",
|
||||
"start": 23,
|
||||
"end": 28,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 16,
|
||||
"score": 0.999740719795227,
|
||||
"is_subword": False,
|
||||
"word": "Pas",
|
||||
"start": 29,
|
||||
"end": 32,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 17,
|
||||
"score": 0.9997414350509644,
|
||||
"is_subword": True,
|
||||
"word": "##tran",
|
||||
"start": 32,
|
||||
"end": 36,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 18,
|
||||
"score": 0.9996136426925659,
|
||||
"is_subword": True,
|
||||
"word": "##a",
|
||||
"start": 36,
|
||||
"end": 37,
|
||||
},
|
||||
{
|
||||
"entity": "B-ORG",
|
||||
"index": 28,
|
||||
"score": 0.9989739060401917,
|
||||
"is_subword": False,
|
||||
"word": "Far",
|
||||
"start": 39,
|
||||
"end": 42,
|
||||
},
|
||||
{
|
||||
"entity": "I-ORG",
|
||||
"index": 29,
|
||||
"score": 0.7188422083854675,
|
||||
"is_subword": True,
|
||||
"word": "##c",
|
||||
"start": 42,
|
||||
"end": 43,
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 1,
|
||||
"score": 0.9968166351318359,
|
||||
"is_subword": False,
|
||||
"word": "En",
|
||||
"start": 0,
|
||||
"end": 2,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 2,
|
||||
"score": 0.9957635998725891,
|
||||
"is_subword": True,
|
||||
"word": "##zo",
|
||||
"start": 2,
|
||||
"end": 4,
|
||||
},
|
||||
{
|
||||
"entity": "I-ORG",
|
||||
"index": 7,
|
||||
"score": 0.9986497163772583,
|
||||
"is_subword": False,
|
||||
"word": "UN",
|
||||
"start": 11,
|
||||
"end": 13,
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
expected_grouped_ner_results = [
|
||||
[
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.999369223912557,
|
||||
"word": "Consuelo Araújo Noguera",
|
||||
"start": 0,
|
||||
"end": 22,
|
||||
},
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.9997771680355072,
|
||||
"word": "Andrés Pastrana",
|
||||
"start": 23,
|
||||
"end": 37,
|
||||
},
|
||||
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43},
|
||||
],
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
]
|
||||
|
||||
expected_grouped_ner_results_w_subword = [
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4},
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.9663328925768534,
|
||||
"word": "##uelo Araújo Noguera",
|
||||
"start": 4,
|
||||
"end": 22,
|
||||
},
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.9997273534536362,
|
||||
"word": "Andrés Pastrana",
|
||||
"start": 23,
|
||||
"end": 37,
|
||||
},
|
||||
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43},
|
||||
],
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
]
|
||||
|
||||
self.assertIsNotNone(nlp)
|
||||
|
||||
mono_result = nlp(VALID_INPUTS[0])
|
||||
@@ -262,15 +61,306 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
for key in output_keys:
|
||||
self.assertIn(key, result)
|
||||
|
||||
if nlp.grouped_entities:
|
||||
if nlp.ignore_subwords:
|
||||
for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
|
||||
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
|
||||
else:
|
||||
for ungrouped_input, grouped_result in zip(
|
||||
ungrouped_ner_inputs, expected_grouped_ner_results_w_subword
|
||||
):
|
||||
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
|
||||
@require_torch
|
||||
@slow
|
||||
def test_spanish_bert(self):
|
||||
# https://github.com/huggingface/transformers/pull/4987
|
||||
NER_MODEL = "mrm8488/bert-spanish-cased-finetuned-ner"
|
||||
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
|
||||
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
|
||||
sentence = """Consuelo Araújo Noguera, ministra de cultura del presidente Andrés Pastrana (1998.2002) fue asesinada por las Farc luego de haber permanecido secuestrada por algunos meses."""
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer)
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output[:3]),
|
||||
[
|
||||
{"entity": "B-PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4, "index": 1},
|
||||
{"entity": "B-PER", "score": 0.803, "word": "##uelo", "start": 4, "end": 8, "index": 2},
|
||||
{"entity": "I-PER", "score": 0.999, "word": "Ara", "start": 9, "end": 12, "index": 3},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output[:3]),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4},
|
||||
{"entity_group": "PER", "score": 0.966, "word": "##uelo Araújo Noguera", "start": 4, "end": 23},
|
||||
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output[:3]),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
|
||||
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output[:3]),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
|
||||
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output[:3]),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.966, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
|
||||
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
|
||||
{"entity_group": "ORG", "score": 0.542, "word": "Farc", "start": 110, "end": 114},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_dbmdz_english(self):
|
||||
# Other sentence
|
||||
NER_MODEL = "dbmdz/bert-large-cased-finetuned-conll03-english"
|
||||
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
|
||||
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
|
||||
sentence = """Enzo works at the the UN"""
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer)
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output),
|
||||
[
|
||||
{"entity": "I-PER", "score": 0.997, "word": "En", "start": 0, "end": 2, "index": 1},
|
||||
{"entity": "I-PER", "score": 0.996, "word": "##zo", "start": 2, "end": 4, "index": 2},
|
||||
{"entity": "I-ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24, "index": 7},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output[:3]),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output[:3]),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
|
||||
],
|
||||
)
|
||||
|
||||
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average")
|
||||
output = token_classifier(sentence)
|
||||
self.assertEqual(
|
||||
nested_simplify(output),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_aggregation_strategy(self):
|
||||
model_name = self.small_models[0]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
|
||||
# Just to understand scores indexes in this test
|
||||
self.assertEqual(
|
||||
token_classifier.model.config.id2label,
|
||||
{0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
|
||||
)
|
||||
example = [
|
||||
{
|
||||
# fmt : off
|
||||
"scores": np.array([0, 0, 0, 0, 0.9968166351318359, 0, 0, 0]),
|
||||
"index": 1,
|
||||
"is_subword": False,
|
||||
"word": "En",
|
||||
"start": 0,
|
||||
"end": 2,
|
||||
},
|
||||
{
|
||||
# fmt : off
|
||||
"scores": np.array([0, 0, 0, 0, 0.9957635998725891, 0, 0, 0]),
|
||||
"index": 2,
|
||||
"is_subword": True,
|
||||
"word": "##zo",
|
||||
"start": 2,
|
||||
"end": 4,
|
||||
},
|
||||
{
|
||||
# fmt: off
|
||||
"scores": np.array([0, 0, 0, 0, 0, 0.9986497163772583, 0, 0, ]),
|
||||
# fmt: on
|
||||
"index": 7,
|
||||
"word": "UN",
|
||||
"is_subword": False,
|
||||
"start": 11,
|
||||
"end": 13,
|
||||
},
|
||||
]
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
|
||||
[
|
||||
{"end": 2, "entity": "I-PER", "score": 0.997, "start": 0, "word": "En", "index": 1},
|
||||
{"end": 4, "entity": "I-PER", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
|
||||
{"end": 13, "entity": "B-ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.FIRST)),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.MAX)),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)),
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_aggregation_strategy_example2(self):
|
||||
model_name = self.small_models[0]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
|
||||
# Just to understand scores indexes in this test
|
||||
self.assertEqual(
|
||||
token_classifier.model.config.id2label,
|
||||
{0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
|
||||
)
|
||||
example = [
|
||||
{
|
||||
# Necessary for AVERAGE
|
||||
"scores": np.array([0, 0.55, 0, 0.45, 0, 0, 0, 0, 0, 0]),
|
||||
"is_subword": False,
|
||||
"index": 1,
|
||||
"word": "Ra",
|
||||
"start": 0,
|
||||
"end": 2,
|
||||
},
|
||||
{
|
||||
"scores": np.array([0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0]),
|
||||
"is_subword": True,
|
||||
"word": "##ma",
|
||||
"start": 2,
|
||||
"end": 4,
|
||||
"index": 2,
|
||||
},
|
||||
{
|
||||
# 4th score will have the higher average
|
||||
# 4th score is B-PER for this model
|
||||
# It's does not correspond to any of the subtokens.
|
||||
"scores": np.array([0, 0, 0, 0.4, 0, 0, 0.6, 0, 0, 0]),
|
||||
"is_subword": True,
|
||||
"word": "##zotti",
|
||||
"start": 11,
|
||||
"end": 13,
|
||||
"index": 3,
|
||||
},
|
||||
]
|
||||
self.assertEqual(
|
||||
token_classifier.aggregate(example, AggregationStrategy.NONE),
|
||||
[
|
||||
{"end": 2, "entity": "B-MISC", "score": 0.55, "start": 0, "word": "Ra", "index": 1},
|
||||
{"end": 4, "entity": "B-LOC", "score": 0.8, "start": 2, "word": "##ma", "index": 2},
|
||||
{"end": 13, "entity": "I-ORG", "score": 0.6, "start": 11, "word": "##zotti", "index": 3},
|
||||
],
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
token_classifier.aggregate(example, AggregationStrategy.FIRST),
|
||||
[{"entity_group": "MISC", "score": 0.55, "word": "Ramazotti", "start": 0, "end": 13}],
|
||||
)
|
||||
self.assertEqual(
|
||||
token_classifier.aggregate(example, AggregationStrategy.MAX),
|
||||
[{"entity_group": "LOC", "score": 0.8, "word": "Ramazotti", "start": 0, "end": 13}],
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)),
|
||||
[{"entity_group": "PER", "score": 0.35, "word": "Ramazotti", "start": 0, "end": 13}],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_gather_pre_entities(self):
|
||||
|
||||
model_name = self.small_models[0]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
|
||||
|
||||
sentence = "Hello there"
|
||||
|
||||
tokens = tokenizer(
|
||||
sentence,
|
||||
return_attention_mask=False,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
return_special_tokens_mask=True,
|
||||
return_offsets_mapping=True,
|
||||
)
|
||||
offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
|
||||
special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]
|
||||
input_ids = tokens["input_ids"].numpy()[0]
|
||||
# First element in [CLS]
|
||||
scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]])
|
||||
|
||||
pre_entities = nlp.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
|
||||
self.assertEqual(
|
||||
nested_simplify(pre_entities),
|
||||
[
|
||||
{"word": "Hello", "scores": [0.1, 0.3, 0.6], "start": 0, "end": 5, "is_subword": False, "index": 1},
|
||||
{
|
||||
"word": "there",
|
||||
"scores": [0.8, 0.1, 0.1],
|
||||
"index": 2,
|
||||
"start": 6,
|
||||
"end": 11,
|
||||
"is_subword": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_tf_only(self):
|
||||
@@ -295,8 +385,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
model=model_name,
|
||||
tokenizer=tokenizer,
|
||||
framework="tf",
|
||||
grouped_entities=True,
|
||||
ignore_subwords=True,
|
||||
aggregation_strategy=AggregationStrategy.FIRST,
|
||||
)
|
||||
self._test_pipeline(nlp)
|
||||
|
||||
@@ -307,18 +396,23 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
model=model_name,
|
||||
tokenizer=tokenizer,
|
||||
framework="tf",
|
||||
grouped_entities=True,
|
||||
ignore_subwords=False,
|
||||
aggregation_strategy=AggregationStrategy.SIMPLE,
|
||||
)
|
||||
self._test_pipeline(nlp)
|
||||
|
||||
@require_torch
|
||||
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
||||
for model_name in self.small_models:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
model_name = self.small_models[0]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True, use_fast=False)
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST)
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline(
|
||||
task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.AVERAGE
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.MAX)
|
||||
|
||||
@require_torch
|
||||
def test_pt_defaults_slow_tokenizer(self):
|
||||
@@ -333,27 +427,27 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
nlp = pipeline(task="ner", model=model_name)
|
||||
self._test_pipeline(nlp)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_warnings(self):
|
||||
with self.assertWarns(UserWarning):
|
||||
token_classifier = pipeline(task="ner", model=self.small_models[0], grouped_entities=True)
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
|
||||
with self.assertWarns(UserWarning):
|
||||
token_classifier = pipeline(
|
||||
task="ner", model=self.small_models[0], grouped_entities=True, ignore_subwords=True
|
||||
)
|
||||
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_simple(self):
|
||||
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
|
||||
nlp = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy=AggregationStrategy.SIMPLE)
|
||||
sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
|
||||
sentence2 = "This is a simple test"
|
||||
output = nlp(sentence)
|
||||
|
||||
def simplify(output):
|
||||
if isinstance(output, (list, tuple)):
|
||||
return [simplify(item) for item in output]
|
||||
elif isinstance(output, dict):
|
||||
return {simplify(k): simplify(v) for k, v in output.items()}
|
||||
elif isinstance(output, (str, int, np.int64)):
|
||||
return output
|
||||
elif isinstance(output, float):
|
||||
return round(output, 3)
|
||||
else:
|
||||
raise Exception(f"Cannot handle {type(output)}")
|
||||
|
||||
output_ = simplify(output)
|
||||
output_ = nested_simplify(output)
|
||||
|
||||
self.assertEqual(
|
||||
output_,
|
||||
@@ -371,7 +465,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
)
|
||||
|
||||
output = nlp([sentence, sentence2])
|
||||
output_ = simplify(output)
|
||||
output_ = nested_simplify(output)
|
||||
|
||||
self.assertEqual(
|
||||
output_,
|
||||
@@ -390,14 +484,14 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
|
||||
for model_name in self.small_models:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
nlp = pipeline(
|
||||
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
|
||||
task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST
|
||||
)
|
||||
self._test_pipeline(nlp)
|
||||
|
||||
for model_name in self.small_models:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
nlp = pipeline(
|
||||
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
|
||||
task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.SIMPLE
|
||||
)
|
||||
self._test_pipeline(nlp)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user