clean pipelines (#3795)

This commit is contained in:
Patrick von Platen
2020-04-16 16:21:34 +02:00
committed by GitHub
parent 38f7461df3
commit baca8fa8e6
2 changed files with 21 additions and 52 deletions

View File

@@ -2,26 +2,19 @@ import unittest
from typing import Iterable, List, Optional
from transformers import pipeline
from transformers.pipelines import (
FeatureExtractionPipeline,
FillMaskPipeline,
NerPipeline,
Pipeline,
QuestionAnsweringPipeline,
TextClassificationPipeline,
)
from transformers.pipelines import Pipeline
from .utils import require_tf, require_torch, slow
QA_FINETUNED_MODELS = [
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
]
TF_QA_FINETUNED_MODELS = [
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
]
TF_NER_FINETUNED_MODELS = {
@@ -369,25 +362,29 @@ class MultiColumnInputTestCase(unittest.TestCase):
class PipelineCommonTests(unittest.TestCase):
pipelines = (
NerPipeline,
FeatureExtractionPipeline,
QuestionAnsweringPipeline,
FillMaskPipeline,
TextClassificationPipeline,
"ner",
"feature-extraction",
"question-answering",
"fill-mask",
"summarization",
"sentiment-analysis",
"translation_en_to_fr",
"translation_en_to_de",
"translation_en_to_ro",
)
@slow
@require_tf
def test_tf_defaults(self):
# Test that pipelines can be correctly loaded without any argument
for default_pipeline in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
default_pipeline(framework="tf")
for task in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
pipeline(task, framework="tf")
@slow
@require_torch
def test_pt_defaults(self):
# Test that pipelines can be correctly loaded without any argument
for default_pipeline in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
default_pipeline(framework="pt")
for task in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
pipeline(task, framework="pt")