clean pipelines (#3795)
This commit is contained in:
committed by
GitHub
parent
38f7461df3
commit
baca8fa8e6
@@ -23,17 +23,12 @@ import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from os.path import abspath, exists
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||
from .configuration_bart import BartConfig
|
||||
from .configuration_distilbert import DistilBertConfig
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .configuration_t5 import T5Config
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .configuration_xlm import XLMConfig
|
||||
from .data import SquadExample, squad_convert_examples_to_features
|
||||
from .file_utils import is_tf_available, is_torch_available
|
||||
from .modelcard import ModelCard
|
||||
@@ -423,27 +418,6 @@ class Pipeline(_ScikitCompat):
|
||||
"""
|
||||
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
||||
|
||||
def inputs_for_model(self, features: Union[dict, List[dict]]) -> Dict:
|
||||
"""
|
||||
Generates the input dictionary with model-specific parameters.
|
||||
|
||||
Returns:
|
||||
dict holding all the required parameters for model's forward
|
||||
"""
|
||||
args = ["input_ids", "attention_mask"]
|
||||
|
||||
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)):
|
||||
args += ["token_type_ids"]
|
||||
|
||||
# PR #1548 (CLI) There is an issue with attention_mask
|
||||
# if 'xlnet' in model_type or 'xlm' in model_type:
|
||||
# args += ['cls_index', 'p_mask']
|
||||
|
||||
if isinstance(features, dict):
|
||||
return {k: features[k] for k in args}
|
||||
else:
|
||||
return {k: [feature[k] for feature in features] for k in args}
|
||||
|
||||
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
|
||||
"""
|
||||
Parse arguments and tokenize
|
||||
@@ -458,9 +432,6 @@ class Pipeline(_ScikitCompat):
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
)
|
||||
|
||||
# Filter out features not available on specific models
|
||||
# inputs = self.inputs_for_model(inputs)
|
||||
|
||||
return inputs
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
@@ -995,7 +966,8 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
]
|
||||
all_answers = []
|
||||
for features, example in zip(features_list, examples):
|
||||
fw_args = self.inputs_for_model([f.__dict__ for f in features])
|
||||
model_input_names = self.tokenizer.model_input_names + ["input_ids"]
|
||||
fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names}
|
||||
|
||||
# Manage tensor allocation on correct device
|
||||
with self.device_placement():
|
||||
|
||||
Reference in New Issue
Block a user