@@ -96,7 +96,6 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
default_input_names = "sequences"
|
default_input_names = "sequences"
|
||||||
|
|
||||||
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
|
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
|
||||||
self.ignore_labels = ["O"]
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.check_model_type(
|
self.check_model_type(
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||||
@@ -216,7 +215,9 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
**model_inputs,
|
**model_inputs,
|
||||||
}
|
}
|
||||||
|
|
||||||
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
|
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
|
||||||
|
if ignore_labels is None:
|
||||||
|
ignore_labels = ["O"]
|
||||||
logits = model_outputs["logits"][0].numpy()
|
logits = model_outputs["logits"][0].numpy()
|
||||||
sentence = model_outputs["sentence"]
|
sentence = model_outputs["sentence"]
|
||||||
input_ids = model_outputs["input_ids"][0]
|
input_ids = model_outputs["input_ids"][0]
|
||||||
@@ -235,8 +236,8 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
entities = [
|
entities = [
|
||||||
entity
|
entity
|
||||||
for entity in grouped_entities
|
for entity in grouped_entities
|
||||||
if entity.get("entity", None) not in self.ignore_labels
|
if entity.get("entity", None) not in ignore_labels
|
||||||
and entity.get("entity_group", None) not in self.ignore_labels
|
and entity.get("entity_group", None) not in ignore_labels
|
||||||
]
|
]
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
|||||||
@@ -627,6 +627,15 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
token_classifier = pipeline(
|
||||||
|
task="token-classification", model=model_name, framework="pt", ignore_labels=["O", "I-MISC"]
|
||||||
|
)
|
||||||
|
outputs = token_classifier("This is a test !")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
||||||
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
||||||
|
|||||||
Reference in New Issue
Block a user