Fixes NoneType exception when topk is larger than one coupled with a small context in the Question-Answering pipeline (#11628)
* added fix to decode function. added test to qa pipeline tests * completed topk docstring * fixed formatting with black * applied style_doc to fix line length
This commit is contained in:
@@ -15,7 +15,8 @@
|
||||
import unittest
|
||||
|
||||
from transformers.data.processors.squad import SquadExample
|
||||
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler
|
||||
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler, pipeline
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||
|
||||
@@ -50,6 +51,34 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
},
|
||||
]
|
||||
|
||||
def get_pipelines(self):
|
||||
question_answering_pipelines = [
|
||||
pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model,
|
||||
tokenizer=model,
|
||||
framework="pt",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
for model in self.small_models
|
||||
]
|
||||
return question_answering_pipelines
|
||||
|
||||
@slow
|
||||
def test_high_topk_small_context(self):
|
||||
self.pipeline_running_kwargs.update({"topk": 20})
|
||||
valid_inputs = [
|
||||
{"question": "Where was HuggingFace founded ?", "context": "Paris"},
|
||||
]
|
||||
nlps = self.get_pipelines()
|
||||
output_keys = {"score", "answer", "start", "end"}
|
||||
for nlp in nlps:
|
||||
result = nlp(valid_inputs, **self.pipeline_running_kwargs)
|
||||
self.assertIsInstance(result, dict)
|
||||
|
||||
for key in output_keys:
|
||||
self.assertIn(key, result)
|
||||
|
||||
def _test_pipeline(self, nlp: Pipeline):
|
||||
output_keys = {"score", "answer", "start", "end"}
|
||||
valid_inputs = [
|
||||
|
||||
Reference in New Issue
Block a user