clean pipelines (#3795)
This commit is contained in:
committed by
GitHub
parent
38f7461df3
commit
baca8fa8e6
@@ -23,17 +23,12 @@ import sys
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from os.path import abspath, exists
|
from os.path import abspath, exists
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||||
from .configuration_bart import BartConfig
|
|
||||||
from .configuration_distilbert import DistilBertConfig
|
|
||||||
from .configuration_roberta import RobertaConfig
|
|
||||||
from .configuration_t5 import T5Config
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .configuration_xlm import XLMConfig
|
|
||||||
from .data import SquadExample, squad_convert_examples_to_features
|
from .data import SquadExample, squad_convert_examples_to_features
|
||||||
from .file_utils import is_tf_available, is_torch_available
|
from .file_utils import is_tf_available, is_torch_available
|
||||||
from .modelcard import ModelCard
|
from .modelcard import ModelCard
|
||||||
@@ -423,27 +418,6 @@ class Pipeline(_ScikitCompat):
|
|||||||
"""
|
"""
|
||||||
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
||||||
|
|
||||||
def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
|
|
||||||
"""
|
|
||||||
Generates the input dictionary with model-specific parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict holding all the required parameters for model's forward
|
|
||||||
"""
|
|
||||||
args = ["input_ids", "attention_mask"]
|
|
||||||
|
|
||||||
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)):
|
|
||||||
args += ["token_type_ids"]
|
|
||||||
|
|
||||||
# PR #1548 (CLI) There is an issue with attention_mask
|
|
||||||
# if 'xlnet' in model_type or 'xlm' in model_type:
|
|
||||||
# args += ['cls_index', 'p_mask']
|
|
||||||
|
|
||||||
if isinstance(features, dict):
|
|
||||||
return {k: features[k] for k in args}
|
|
||||||
else:
|
|
||||||
return {k: [feature[k] for feature in features] for k in args}
|
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
|
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Parse arguments and tokenize
|
Parse arguments and tokenize
|
||||||
@@ -458,9 +432,6 @@ class Pipeline(_ScikitCompat):
|
|||||||
pad_to_max_length=pad_to_max_length,
|
pad_to_max_length=pad_to_max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter out features not available on specific models
|
|
||||||
# inputs = self.inputs_for_model(inputs)
|
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def __call__(self, *texts, **kwargs):
|
def __call__(self, *texts, **kwargs):
|
||||||
@@ -995,7 +966,8 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
]
|
]
|
||||||
all_answers = []
|
all_answers = []
|
||||||
for features, example in zip(features_list, examples):
|
for features, example in zip(features_list, examples):
|
||||||
fw_args = self.inputs_for_model([f.__dict__ for f in features])
|
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
|
||||||
|
fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}
|
||||||
|
|
||||||
# Manage tensor allocation on correct device
|
# Manage tensor allocation on correct device
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
|
|||||||
@@ -2,26 +2,19 @@ import unittest
|
|||||||
from typing import Iterable, List, Optional
|
from typing import Iterable, List, Optional
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
from transformers.pipelines import (
|
from transformers.pipelines import Pipeline
|
||||||
FeatureExtractionPipeline,
|
|
||||||
FillMaskPipeline,
|
|
||||||
NerPipeline,
|
|
||||||
Pipeline,
|
|
||||||
QuestionAnsweringPipeline,
|
|
||||||
TextClassificationPipeline,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .utils import require_tf, require_torch, slow
|
from .utils import require_tf, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
QA_FINETUNED_MODELS = [
|
QA_FINETUNED_MODELS = [
|
||||||
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
||||||
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||||
]
|
]
|
||||||
|
|
||||||
TF_QA_FINETUNED_MODELS = [
|
TF_QA_FINETUNED_MODELS = [
|
||||||
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
(("bert-base-uncased", {"use_fast": False}), "bert-large-uncased-whole-word-masking-finetuned-squad", None),
|
||||||
(("bert-base-cased", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
(("distilbert-base-cased-distilled-squad", {"use_fast": False}), "distilbert-base-cased-distilled-squad", None),
|
||||||
]
|
]
|
||||||
|
|
||||||
TF_NER_FINETUNED_MODELS = {
|
TF_NER_FINETUNED_MODELS = {
|
||||||
@@ -369,25 +362,29 @@ class MultiColumnInputTestCase(unittest.TestCase):
|
|||||||
class PipelineCommonTests(unittest.TestCase):
|
class PipelineCommonTests(unittest.TestCase):
|
||||||
|
|
||||||
pipelines = (
|
pipelines = (
|
||||||
NerPipeline,
|
"ner",
|
||||||
FeatureExtractionPipeline,
|
"feature-extraction",
|
||||||
QuestionAnsweringPipeline,
|
"question-answering",
|
||||||
FillMaskPipeline,
|
"fill-mask",
|
||||||
TextClassificationPipeline,
|
"summarization",
|
||||||
|
"sentiment-analysis",
|
||||||
|
"translation_en_to_fr",
|
||||||
|
"translation_en_to_de",
|
||||||
|
"translation_en_to_ro",
|
||||||
)
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_tf_defaults(self):
|
def test_tf_defaults(self):
|
||||||
# Test that pipelines can be correctly loaded without any argument
|
# Test that pipelines can be correctly loaded without any argument
|
||||||
for default_pipeline in self.pipelines:
|
for task in self.pipelines:
|
||||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
|
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
||||||
default_pipeline(framework="tf")
|
pipeline(task, framework="tf")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pt_defaults(self):
|
def test_pt_defaults(self):
|
||||||
# Test that pipelines can be correctly loaded without any argument
|
# Test that pipelines can be correctly loaded without any argument
|
||||||
for default_pipeline in self.pipelines:
|
for task in self.pipelines:
|
||||||
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(default_pipeline.task)):
|
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
|
||||||
default_pipeline(framework="pt")
|
pipeline(task, framework="pt")
|
||||||
|
|||||||
Reference in New Issue
Block a user