Pipeline should be agnostic (#12656)
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from transformers import is_tf_available, is_torch_available
|
||||||
from transformers.data.processors.squad import SquadExample
|
from transformers.data.processors.squad import SquadExample
|
||||||
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler, pipeline
|
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler, pipeline
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
@@ -57,7 +58,7 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
task=self.pipeline_task,
|
task=self.pipeline_task,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=model,
|
tokenizer=model,
|
||||||
framework="pt",
|
framework="pt" if is_torch_available() else "tf",
|
||||||
**self.pipeline_loading_kwargs,
|
**self.pipeline_loading_kwargs,
|
||||||
)
|
)
|
||||||
for model in self.small_models
|
for model in self.small_models
|
||||||
@@ -65,6 +66,7 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
return question_answering_pipelines
|
return question_answering_pipelines
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
@unittest.skipIf(not is_torch_available() and not is_tf_available(), "Either torch or TF must be installed.")
|
||||||
def test_high_topk_small_context(self):
|
def test_high_topk_small_context(self):
|
||||||
self.pipeline_running_kwargs.update({"topk": 20})
|
self.pipeline_running_kwargs.update({"topk": 20})
|
||||||
valid_inputs = [
|
valid_inputs = [
|
||||||
|
|||||||
Reference in New Issue
Block a user