clean pipelines (#3795)
This commit is contained in:
committed by
GitHub
parent
38f7461df3
commit
baca8fa8e6
@@ -23,17 +23,12 @@ import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from os.path import abspath, exists
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||
from .configuration_bart import BartConfig
|
||||
from .configuration_distilbert import DistilBertConfig
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .configuration_t5 import T5Config
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .configuration_xlm import XLMConfig
|
||||
from .data import SquadExample, squad_convert_examples_to_features
|
||||
from .file_utils import is_tf_available, is_torch_available
|
||||
from .modelcard import ModelCard
|
||||
@@ -423,27 +418,6 @@ class Pipeline(_ScikitCompat):
|
||||
"""
|
||||
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
||||
|
||||
def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
    """
    Pick out of the tokenized features the entries that are forwarded to the model.

    ``input_ids`` and ``attention_mask`` are always selected. ``token_type_ids`` is
    selected as well, unless ``self.model.config`` is an instance of one of the
    DistilBert, XLM, Roberta, Bart or T5 configuration classes.

    Args:
        features: either a single feature dict, or a list of feature dicts.

    Returns:
        dict keyed by argument name. For a single feature dict the values are taken
        as-is; for a list, each value is the list of that key's values across all
        features, in order.
    """
    without_token_type_ids = (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)

    selected = ["input_ids", "attention_mask"]
    if not isinstance(self.model.config, without_token_type_ids):
        selected += ["token_type_ids"]

    # PR #1548 (CLI) There is an issue with attention_mask
    # if 'xlnet' in model_type or 'xlm' in model_type:
    #     args += ['cls_index', 'p_mask']

    # Guard clause for the batched case; a lone dict falls through below.
    if not isinstance(features, dict):
        return {key: [item[key] for item in features] for key in selected}
    return {key: features[key] for key in selected}
|
||||
|
||||
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
|
||||
"""
|
||||
Parse arguments and tokenize
|
||||
@@ -458,9 +432,6 @@ class Pipeline(_ScikitCompat):
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
)
|
||||
|
||||
# Filter out features not available on specific models
|
||||
# inputs = self.inputs_for_model(inputs)
|
||||
|
||||
return inputs
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
@@ -995,7 +966,8 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
]
|
||||
all_answers = []
|
||||
for features, example in zip(features_list, examples):
|
||||
fw_args = self.inputs_for_model([f.__dict__ for f in features])
|
||||
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
|
||||
fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}
|
||||
|
||||
# Manage tensor allocation on correct device
|
||||
with self.device_placement():
|
||||
|
||||
@@ -2,26 +2,19 @@ import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import (
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
NerPipeline,
|
||||
Pipeline,
|
||||
QuestionAnsweringPipeline,
|
||||
TextClassificationPipeline,
|
||||
)
|
||||
from transformers.pipelines import Pipeline
|
||||
|
||||
from .utils import require_tf, require_torch, slow
|
||||
|
||||
|
||||
QA_FINETUNED_MODELS = [
|
||||
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
||||
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
]
|
||||
|
||||
TF_QA_FINETUNED_MODELS = [
|
||||
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
||||
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||
]
|
||||
|
||||
TF_NER_FINETUNED_MODELS = {
|
||||
@@ -369,25 +362,29 @@ class MultiColumnInputTestCase(unittest.TestCase):
|
||||
class PipelineCommonTests(unittest.TestCase):
    """Check that each default pipeline can be created from its task name alone."""

    # Task identifiers passed to the `pipeline` factory. Only strings belong here:
    # the tests below hand each entry straight to `pipeline(task, ...)`.
    pipelines = (
        "ner",
        "feature-extraction",
        "question-answering",
        "fill-mask",
        "summarization",
        "sentiment-analysis",
        "translation_en_to_fr",
        "translation_en_to_de",
        "translation_en_to_ro",
    )

    @slow
    @require_tf
    def test_tf_defaults(self):
        # Test that pipelines can be correctly loaded without any argument
        for task in self.pipelines:
            # The message names the framework actually requested (framework="tf"),
            # so a failing sub-test is not mistaken for a PyTorch failure.
            with self.subTest(msg="Testing TF defaults with TensorFlow and {}".format(task)):
                pipeline(task, framework="tf")

    @slow
    @require_torch
    def test_pt_defaults(self):
        # Test that pipelines can be correctly loaded without any argument
        for task in self.pipelines:
            with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
                pipeline(task, framework="pt")
|
||||
|
||||
Reference in New Issue
Block a user