Pipeline doc (#3055)
* Pipeline doc initial commit * pipeline abstraction * Remove modelcard argument from pipeline * Task-specific pipelines can be instantiated with no model or tokenizer * All pipelines doc
This commit is contained in:
@@ -2,9 +2,16 @@ import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.pipelines import (
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
NerPipeline,
|
||||
Pipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
TextClassificationPipeline,
|
||||
)
|
||||
|
||||
from .utils import require_tf, require_torch
|
||||
from .utils import require_tf, require_torch, slow
|
||||
|
||||
|
||||
QA_FINETUNED_MODELS = [
|
||||
@@ -304,3 +311,30 @@ class MultiColumnInputTestCase(unittest.TestCase):
|
||||
for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer, framework="tf")
|
||||
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
|
||||
|
||||
|
||||
class PipelineCommonTests(unittest.TestCase):
|
||||
|
||||
pipelines = (
|
||||
NerPipeline,
|
||||
FeatureExtractionPipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
FillMaskPipeline,
|
||||
TextClassificationPipeline,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tf_defaults(self):
|
||||
# Test that pipelines can be correctly loaded without any argument
|
||||
for default_pipeline in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
|
||||
default_pipeline(framework="tf")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pt_defaults(self):
|
||||
# Test that pipelines can be correctly loaded without any argument
|
||||
for default_pipeline in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
|
||||
default_pipeline(framework="pt")
|
||||
|
||||
Reference in New Issue
Block a user