* Add index to be returned by NerPipeline to allow for the creation of * Add entity groups * Convert entity list to dict * Add entity to entity_group_disagg atfter updating entity gorups * Change 'group' parameter to 'grouped_entities' * Add unit tests for grouped NER pipeline case * Correct variable name typo for NER_FINETUNED_MODELS * Sync grouped tests to recent test updates
This commit is contained in:
@@ -868,6 +868,7 @@ class NerPipeline(Pipeline):
|
||||
binary_output: bool = False,
|
||||
ignore_labels=["O"],
|
||||
task: str = "",
|
||||
grouped_entities: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
@@ -882,6 +883,7 @@ class NerPipeline(Pipeline):
|
||||
|
||||
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
||||
self.ignore_labels = ignore_labels
|
||||
self.grouped_entities = grouped_entities
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
inputs = self._args_parser(*args, **kwargs)
|
||||
@@ -911,23 +913,74 @@ class NerPipeline(Pipeline):
|
||||
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
|
||||
labels_idx = score.argmax(axis=-1)
|
||||
|
||||
answer = []
|
||||
for idx, label_idx in enumerate(labels_idx):
|
||||
if self.model.config.id2label[label_idx] not in self.ignore_labels:
|
||||
answer += [
|
||||
{
|
||||
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
|
||||
"score": score[idx][label_idx].item(),
|
||||
"entity": self.model.config.id2label[label_idx],
|
||||
}
|
||||
]
|
||||
entities = []
|
||||
entity_groups = []
|
||||
entity_group_disagg = []
|
||||
# Filter to labels not in `self.ignore_labels`
|
||||
filtered_labels_idx = [
|
||||
(idx, label_idx)
|
||||
for idx, label_idx in enumerate(labels_idx)
|
||||
if self.model.config.id2label[label_idx] not in self.ignore_labels
|
||||
]
|
||||
|
||||
for idx, label_idx in filtered_labels_idx:
|
||||
|
||||
entity = {
|
||||
"word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
|
||||
"score": score[idx][label_idx].item(),
|
||||
"entity": self.model.config.id2label[label_idx],
|
||||
"index": idx,
|
||||
}
|
||||
last_idx, _ = filtered_labels_idx[-1]
|
||||
if self.grouped_entities:
|
||||
if not entity_group_disagg:
|
||||
entity_group_disagg += [entity]
|
||||
if idx == last_idx:
|
||||
entity_groups += [self.group_entities(entity_group_disagg)]
|
||||
continue
|
||||
|
||||
# If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
|
||||
if (
|
||||
entity["entity"] == entity_group_disagg[-1]["entity"]
|
||||
and entity["index"] == entity_group_disagg[-1]["index"] + 1
|
||||
):
|
||||
entity_group_disagg += [entity]
|
||||
# Group the entities at the last entity
|
||||
if idx == last_idx:
|
||||
entity_groups += [self.group_entities(entity_group_disagg)]
|
||||
# If the current entity is different from the previous entity, aggregate the disaggregated entity group
|
||||
else:
|
||||
entity_groups += [self.group_entities(entity_group_disagg)]
|
||||
entity_group_disagg = [entity]
|
||||
|
||||
entities += [entity]
|
||||
|
||||
# Append
|
||||
answers += [answer]
|
||||
if self.grouped_entities:
|
||||
answers += [entity_groups]
|
||||
else:
|
||||
answers += [entities]
|
||||
|
||||
if len(answers) == 1:
|
||||
return answers[0]
|
||||
return answers
|
||||
|
||||
def group_entities(self, entities):
|
||||
"""
|
||||
Returns grouped entities
|
||||
"""
|
||||
# Get the last entity in the entity group
|
||||
entity = entities[-1]["entity"]
|
||||
scores = np.mean([entity["score"] for entity in entities])
|
||||
tokens = [entity["word"] for entity in entities]
|
||||
|
||||
entity_group = {
|
||||
"entity_group": entity,
|
||||
"score": np.mean(scores),
|
||||
"word": self.tokenizer.convert_tokens_to_string(tokens),
|
||||
}
|
||||
return entity_group
|
||||
|
||||
|
||||
TokenClassificationPipeline = NerPipeline
|
||||
|
||||
|
||||
Reference in New Issue
Block a user