clean pipelines (#3795)
This commit is contained in:
committed by
GitHub
parent
38f7461df3
commit
baca8fa8e6
@@ -2,26 +2,19 @@ import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import (
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
NerPipeline,
|
||||
Pipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
TextClassificationPipeline,
|
||||
)
|
||||
from transformers.pipelines import Pipeline
|
||||
|
||||
from .utils import require_tf, require_torch, slow
|
||||
|
||||
|
||||
QA_FINETUNED_MODELS = [
|
||||
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
||||
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
]
|
||||
|
||||
TF_QA_FINETUNED_MODELS = [
|
||||
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
||||
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
]
|
||||
|
||||
TF_NER_FINETUNED_MODELS = {
|
||||
@@ -369,25 +362,29 @@ class MultiColumnInputTestCase(unittest.TestCase):
|
||||
class PipelineCommonTests(unittest.TestCase):
|
||||
|
||||
pipelines = (
|
||||
NerPipeline,
|
||||
FeatureExtractionPipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
FillMaskPipeline,
|
||||
TextClassificationPipeline,
|
||||
"ner",
|
||||
"feature-extraction",
|
||||
"question-answering",
|
||||
"fill-mask",
|
||||
"summarization",
|
||||
"sentiment-analysis",
|
||||
"translation_en_to_fr",
|
||||
"translation_en_to_de",
|
||||
"translation_en_to_ro",
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tf_defaults(self):
|
||||
# Test that pipelines can be correctly loaded without any argument
|
||||
for default_pipeline in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
|
||||
default_pipeline(framework="tf")
|
||||
for task in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
||||
pipeline(task, framework="tf")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pt_defaults(self):
|
||||
# Test that pipelines can be correctly loaded without any argument
|
||||
for default_pipeline in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
|
||||
default_pipeline(framework="pt")
|
||||
for task in self.pipelines:
|
||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
||||
pipeline(task, framework="pt")
|
||||
|
||||
Reference in New Issue
Block a user