* Add index to be returned by NerPipeline to allow for the creation of * Add entity groups * Convert entity list to dict * Add entity to entity_group_disagg atfter updating entity gorups * Change 'group' parameter to 'grouped_entities' * Add unit tests for grouped NER pipeline case * Correct variable name typo for NER_FINETUNED_MODELS * Sync grouped tests to recent test updates
This commit is contained in:
@@ -868,6 +868,7 @@ class NerPipeline(Pipeline):
|
|||||||
binary_output: bool = False,
|
binary_output: bool = False,
|
||||||
ignore_labels=["O"],
|
ignore_labels=["O"],
|
||||||
task: str = "",
|
task: str = "",
|
||||||
|
grouped_entities: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -882,6 +883,7 @@ class NerPipeline(Pipeline):
|
|||||||
|
|
||||||
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
||||||
self.ignore_labels = ignore_labels
|
self.ignore_labels = ignore_labels
|
||||||
|
self.grouped_entities = grouped_entities
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
inputs = self._args_parser(*args, **kwargs)
|
inputs = self._args_parser(*args, **kwargs)
|
||||||
@@ -911,23 +913,74 @@ class NerPipeline(Pipeline):
|
|||||||
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
|
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
|
||||||
labels_idx = score.argmax(axis=-1)
|
labels_idx = score.argmax(axis=-1)
|
||||||
|
|
||||||
answer = []
|
entities = []
|
||||||
for idx, label_idx in enumerate(labels_idx):
|
entity_groups = []
|
||||||
if self.model.config.id2label[label_idx] not in self.ignore_labels:
|
entity_group_disagg = []
|
||||||
answer += [
|
# Filter to labels not in `self.ignore_labels`
|
||||||
{
|
filtered_labels_idx = [
|
||||||
|
(idx, label_idx)
|
||||||
|
for idx, label_idx in enumerate(labels_idx)
|
||||||
|
if self.model.config.id2label[label_idx] not in self.ignore_labels
|
||||||
|
]
|
||||||
|
|
||||||
|
for idx, label_idx in filtered_labels_idx:
|
||||||
|
|
||||||
|
entity = {
|
||||||
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
|
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
|
||||||
"score": score[idx][label_idx].item(),
|
"score": score[idx][label_idx].item(),
|
||||||
"entity": self.model.config.id2label[label_idx],
|
"entity": self.model.config.id2label[label_idx],
|
||||||
|
"index": idx,
|
||||||
}
|
}
|
||||||
]
|
last_idx, _ = filtered_labels_idx[-1]
|
||||||
|
if self.grouped_entities:
|
||||||
|
if not entity_group_disagg:
|
||||||
|
entity_group_disagg += [entity]
|
||||||
|
if idx == last_idx:
|
||||||
|
entity_groups += [self.group_entities(entity_group_disagg)]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
|
||||||
|
if (
|
||||||
|
entity["entity"] == entity_group_disagg[-1]["entity"]
|
||||||
|
and entity["index"] == entity_group_disagg[-1]["index"] + 1
|
||||||
|
):
|
||||||
|
entity_group_disagg += [entity]
|
||||||
|
# Group the entities at the last entity
|
||||||
|
if idx == last_idx:
|
||||||
|
entity_groups += [self.group_entities(entity_group_disagg)]
|
||||||
|
# If the current entity is different from the previous entity, aggregate the disaggregated entity group
|
||||||
|
else:
|
||||||
|
entity_groups += [self.group_entities(entity_group_disagg)]
|
||||||
|
entity_group_disagg = [entity]
|
||||||
|
|
||||||
|
entities += [entity]
|
||||||
|
|
||||||
# Append
|
# Append
|
||||||
answers += [answer]
|
if self.grouped_entities:
|
||||||
|
answers += [entity_groups]
|
||||||
|
else:
|
||||||
|
answers += [entities]
|
||||||
|
|
||||||
if len(answers) == 1:
|
if len(answers) == 1:
|
||||||
return answers[0]
|
return answers[0]
|
||||||
return answers
|
return answers
|
||||||
|
|
||||||
|
def group_entities(self, entities):
|
||||||
|
"""
|
||||||
|
Returns grouped entities
|
||||||
|
"""
|
||||||
|
# Get the last entity in the entity group
|
||||||
|
entity = entities[-1]["entity"]
|
||||||
|
scores = np.mean([entity["score"] for entity in entities])
|
||||||
|
tokens = [entity["word"] for entity in entities]
|
||||||
|
|
||||||
|
entity_group = {
|
||||||
|
"entity_group": entity,
|
||||||
|
"score": np.mean(scores),
|
||||||
|
"word": self.tokenizer.convert_tokens_to_string(tokens),
|
||||||
|
}
|
||||||
|
return entity_group
|
||||||
|
|
||||||
|
|
||||||
TokenClassificationPipeline = NerPipeline
|
TokenClassificationPipeline = NerPipeline
|
||||||
|
|
||||||
|
|||||||
@@ -160,6 +160,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
|
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
|
||||||
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
|
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_ner_grouped(self):
|
||||||
|
mandatory_keys = {"entity_group", "word", "score"}
|
||||||
|
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
|
||||||
|
for model_name in NER_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True)
|
||||||
|
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_tf_ner(self):
|
def test_tf_ner(self):
|
||||||
mandatory_keys = {"entity", "word", "score"}
|
mandatory_keys = {"entity", "word", "score"}
|
||||||
@@ -168,6 +176,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf")
|
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf")
|
||||||
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
|
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_tf_ner_grouped(self):
|
||||||
|
mandatory_keys = {"entity_group", "word", "score"}
|
||||||
|
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
|
||||||
|
for model_name in NER_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
|
||||||
|
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_torch_sentiment_analysis(self):
|
def test_torch_sentiment_analysis(self):
|
||||||
mandatory_keys = {"label", "score"}
|
mandatory_keys = {"label", "score"}
|
||||||
|
|||||||
Reference in New Issue
Block a user