Adding some quality of life for pipeline function. (#14322)
* Adding some quality of life for `pipeline` function. * Update docs/source/main_classes/pipelines.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Improve the tests. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -29,8 +29,10 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
IBertConfig,
|
||||
RobertaConfig,
|
||||
TextClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.pipelines import get_task
|
||||
from transformers.pipelines.base import _pad
|
||||
from transformers.testing_utils import is_pipeline_test, require_torch
|
||||
|
||||
@@ -261,6 +263,29 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
for output in text_classifier(dataset):
|
||||
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
|
||||
|
||||
@require_torch
|
||||
def test_check_task_auto_inference(self):
|
||||
pipe = pipeline(model="Narsil/tiny-distilbert-sequence-classification")
|
||||
|
||||
self.assertIsInstance(pipe, TextClassificationPipeline)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_override(self):
|
||||
class MyPipeline(TextClassificationPipeline):
|
||||
pass
|
||||
|
||||
text_classifier = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline)
|
||||
|
||||
self.assertIsInstance(text_classifier, MyPipeline)
|
||||
|
||||
def test_check_task(self):
|
||||
task = get_task("gpt2")
|
||||
self.assertEqual(task, "text-generation")
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
# Wrong framework
|
||||
get_task("espnet/siddhana_slurp_entity_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best")
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelinePadTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user