Allow the creation of "entity groups" for NerPipeline #3548 (#3957)

* Add index to be returned by NerPipeline to allow for the creation of entity groups

* Add entity groups

* Convert entity list to dict

* Add entity to entity_group_disagg after updating entity groups

* Change 'group' parameter to 'grouped_entities'

* Add unit tests for grouped NER pipeline case

* Correct variable name typo for NER_FINETUNED_MODELS

* Sync grouped tests to recent test updates
This commit is contained in:
Lorenzo Ampil
2020-05-17 15:25:17 +08:00
committed by GitHub
parent 3e0f062106
commit 18d233d525
2 changed files with 80 additions and 11 deletions

View File

@@ -160,6 +160,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
@require_torch
def test_ner_grouped(self):
    # Builds a "ner" pipeline with grouped_entities=True for each entry of
    # NER_FINETUNED_MODELS and runs it through the shared mono-column helper.
    # The key set below is what gets passed to that helper as mandatory keys.
    expected_keys = {"entity_group", "word", "score"}
    sentences = [
        "HuggingFace is solving NLP one commit at a time.",
        "HuggingFace is based in New-York & Paris",
    ]
    for checkpoint in NER_FINETUNED_MODELS:
        # The same identifier is used for both the model and its tokenizer.
        grouped_ner = pipeline(
            task="ner",
            model=checkpoint,
            tokenizer=checkpoint,
            grouped_entities=True,
        )
        self._test_mono_column_pipeline(grouped_ner, sentences, expected_keys)
@require_tf
def test_tf_ner(self):
mandatory_keys = {"entity", "word", "score"}
@@ -168,6 +176,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf")
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys)
@require_tf
def test_tf_ner_grouped(self):
    # Builds a "ner" pipeline with framework="tf" and grouped_entities=True
    # for each entry of NER_FINETUNED_MODELS and runs it through the shared
    # mono-column helper. The key set below is passed to it as mandatory keys.
    expected_keys = {"entity_group", "word", "score"}
    sentences = [
        "HuggingFace is solving NLP one commit at a time.",
        "HuggingFace is based in New-York & Paris",
    ]
    for checkpoint in NER_FINETUNED_MODELS:
        # The same identifier is used for both the model and its tokenizer.
        grouped_ner = pipeline(
            task="ner",
            model=checkpoint,
            tokenizer=checkpoint,
            framework="tf",
            grouped_entities=True,
        )
        self._test_mono_column_pipeline(grouped_ner, sentences, expected_keys)
@require_torch
def test_torch_sentiment_analysis(self):
mandatory_keys = {"label", "score"}