Fixes NoneType exception when topk is larger than one coupled with a small context in the Question-Answering pipeline (#11628)

* added fix to decode function. added test to qa pipeline tests

* completed topk docstring

* fixed formatting with black

* applied style_doc to fix line length
This commit is contained in:
Pavel Soriano
2021-05-10 19:28:10 +02:00
committed by GitHub
parent dcb0e61430
commit 9120ae7d66
2 changed files with 46 additions and 6 deletions

View File

@@ -15,7 +15,8 @@
import unittest
from transformers.data.processors.squad import SquadExample
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler, pipeline
from transformers.testing_utils import slow
from .test_pipelines_common import CustomInputPipelineCommonMixin
@@ -50,6 +51,34 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
},
]
def get_pipelines(self):
    """Build one PyTorch pipeline per entry of ``self.small_models``.

    Each entry is passed as both the model and the tokenizer, together with
    ``self.pipeline_task`` and any extra ``self.pipeline_loading_kwargs``.

    Returns:
        list: the instantiated pipelines, in the same order as ``self.small_models``.
    """
    return [
        pipeline(
            task=self.pipeline_task,
            model=checkpoint,
            tokenizer=checkpoint,
            framework="pt",
            **self.pipeline_loading_kwargs,
        )
        for checkpoint in self.small_models
    ]
@slow
def test_high_topk_small_context(self):
    """Check that a ``topk`` far larger than a tiny context can satisfy still returns an answer.

    Every pipeline from ``get_pipelines()`` is run on a one-word context with
    ``topk=20``; the result must be a single ``dict`` exposing the ``score``,
    ``answer``, ``start`` and ``end`` keys.
    """
    # Build a per-test copy instead of calling ``.update()`` on
    # ``self.pipeline_running_kwargs``: that dict is defined outside this method, and
    # mutating it in place would leak ``topk=20`` into any other test sharing the object.
    running_kwargs = {**self.pipeline_running_kwargs, "topk": 20}
    valid_inputs = [
        {"question": "Where was HuggingFace founded ?", "context": "Paris"},
    ]
    output_keys = {"score", "answer", "start", "end"}
    for nlp in self.get_pipelines():
        result = nlp(valid_inputs, **running_kwargs)
        self.assertIsInstance(result, dict)
        for key in output_keys:
            self.assertIn(key, result)
def _test_pipeline(self, nlp: Pipeline):
output_keys = {"score", "answer", "start", "end"}
valid_inputs = [