GPT-Neo ONNX export (#12911)
GPT-Neo ONNX export and task / feature refactoring Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
|
||||
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
@@ -43,7 +43,7 @@ if is_flax_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
|
||||
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_gpt_neo import (
|
||||
|
||||
@@ -14,7 +14,12 @@
|
||||
# limitations under the License.
|
||||
""" GPT Neo model configuration """
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -173,3 +178,140 @@ class GPTNeoConfig(PretrainedConfig):
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.num_layers
|
||||
|
||||
|
||||
def custom_unfold(input, dimension, size, step):
|
||||
"""Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
|
||||
import torch
|
||||
|
||||
shape = input.size()
|
||||
rank = len(shape)
|
||||
sizedim = shape[dimension]
|
||||
|
||||
low_indices = torch.arange(0, sizedim, step)
|
||||
min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
|
||||
indices = torch.arange(size) + low_indices[:min_length][:, None]
|
||||
|
||||
s = [slice(None)] * rank
|
||||
s[dimension] = indices
|
||||
sliced = input[s]
|
||||
|
||||
perm = list(range(0, rank + 1))
|
||||
perm.append(perm.pop(dimension + 1))
|
||||
|
||||
return sliced.permute(perm)
|
||||
|
||||
|
||||
def custom_get_block_length_and_num_blocks(seq_length, window_size):
|
||||
"""
|
||||
Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
|
||||
original implmentation uses Python variables and control flow.
|
||||
"""
|
||||
import torch
|
||||
|
||||
candidates = torch.arange(1, window_size)
|
||||
remainders = torch.remainder(seq_length, candidates)
|
||||
divisor_indices = remainders == 0
|
||||
divisors = candidates[divisor_indices]
|
||||
largest_divisor = torch.max(divisors)
|
||||
return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
|
||||
|
||||
|
||||
class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from .modeling_gpt_neo import GPTNeoAttentionMixin
|
||||
|
||||
patching_specs = [
|
||||
PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
|
||||
PatchingSpec(
|
||||
GPTNeoAttentionMixin,
|
||||
name="_get_block_length_and_num_blocks",
|
||||
custom_op=custom_get_block_length_and_num_blocks,
|
||||
op_wrapper=staticmethod,
|
||||
),
|
||||
]
|
||||
|
||||
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
|
||||
|
||||
self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"])
|
||||
self._key_values_dynamic_axis = []
|
||||
for i in range(self._config.num_layers):
|
||||
if self._config.attention_layers[i] == "local":
|
||||
self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
|
||||
else:
|
||||
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
|
||||
self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
|
||||
|
||||
@property
|
||||
def _number_key_values(self):
|
||||
return (self._config.num_layers * 2) - self._num_local_attention
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
for i in range(self._number_key_values):
|
||||
common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
|
||||
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = super().outputs
|
||||
if self.use_past:
|
||||
for i in range(self._number_key_values):
|
||||
common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
|
||||
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
|
||||
# We need to order the input in the way they appears in the forward()
|
||||
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
|
||||
|
||||
batch = common_inputs["input_ids"].shape[0]
|
||||
past_shapes = {
|
||||
"global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads),
|
||||
"local": (batch, 1, self._config.hidden_size),
|
||||
}
|
||||
|
||||
# Need to add the past_keys
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
|
||||
ordered_inputs["past_key_values"] = []
|
||||
for i in range(self._config.num_layers):
|
||||
attention_type = self._config.attention_layers[i]
|
||||
if attention_type == "global":
|
||||
ordered_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(past_shapes[attention_type]),
|
||||
torch.zeros(past_shapes[attention_type]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),))
|
||||
|
||||
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
|
||||
if self.use_past:
|
||||
ordered_inputs["attention_mask"] = torch.cat(
|
||||
[ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@@ -1121,7 +1121,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
@@ -13,6 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
|
||||
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
|
||||
from .convert import export, validate_model_outputs
|
||||
from .utils import ParameterFormat, compute_serialized_parameters_size
|
||||
|
||||
@@ -14,101 +14,22 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from transformers.models.albert import AlbertOnnxConfig
|
||||
from transformers.models.auto import AutoTokenizer
|
||||
from transformers.models.bart import BartOnnxConfig
|
||||
from transformers.models.bert import BertOnnxConfig
|
||||
from transformers.models.distilbert import DistilBertOnnxConfig
|
||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||
from transformers.models.longformer import LongformerOnnxConfig
|
||||
from transformers.models.roberta import RobertaOnnxConfig
|
||||
from transformers.models.t5 import T5OnnxConfig
|
||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
|
||||
from .. import is_torch_available
|
||||
from ..utils import logging
|
||||
from .convert import export, validate_model_outputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import AutoModel, PreTrainedModel
|
||||
|
||||
FEATURES_TO_AUTOMODELS = {
|
||||
"default": AutoModel,
|
||||
}
|
||||
|
||||
|
||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||
SUPPORTED_MODEL_KIND = {
|
||||
"albert": {"default": AlbertOnnxConfig.default},
|
||||
"bart": {"default": BartOnnxConfig.default},
|
||||
"bert": {"default": BertOnnxConfig.default},
|
||||
"distilbert": {"default": DistilBertOnnxConfig.default},
|
||||
"gpt2": {"default": GPT2OnnxConfig.default},
|
||||
"longformer": {"default": LongformerOnnxConfig.default},
|
||||
"roberta": {"default": RobertaOnnxConfig},
|
||||
"t5": {"default": T5OnnxConfig.default},
|
||||
"xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
|
||||
}
|
||||
|
||||
|
||||
def get_model_from_features(features: str, model: str):
|
||||
"""
|
||||
Attempt to retrieve a model from a model's name and the features to be enabled.
|
||||
|
||||
Args:
|
||||
features: The features required
|
||||
model: The name of the model to export
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if features not in FEATURES_TO_AUTOMODELS:
|
||||
raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}")
|
||||
|
||||
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
|
||||
|
||||
|
||||
def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
|
||||
"""
|
||||
Check whether or not the model has the requested features
|
||||
|
||||
Args:
|
||||
model: The model to export
|
||||
features: The name of the features to check if they are avaiable
|
||||
|
||||
Returns:
|
||||
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
|
||||
|
||||
"""
|
||||
if model.config.model_type not in SUPPORTED_MODEL_KIND:
|
||||
raise KeyError(
|
||||
f"{model.config.model_type} ({model.name}) is not supported yet. "
|
||||
f"Only {SUPPORTED_MODEL_KIND} are supported. "
|
||||
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
|
||||
)
|
||||
|
||||
# Look for the features
|
||||
model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
|
||||
if features not in model_features:
|
||||
raise ValueError(
|
||||
f"{model.config.model_type} doesn't support features {features}. "
|
||||
f"Supported values are: {list(model_features.keys())}"
|
||||
)
|
||||
|
||||
return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
|
||||
from .features import FeaturesManager
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
|
||||
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
|
||||
parser.add_argument(
|
||||
"--features",
|
||||
choices=["default"],
|
||||
"--feature",
|
||||
choices=list(FeaturesManager.AVAILABLE_FEATURES),
|
||||
default="default",
|
||||
help="Export the model with some additional features.",
|
||||
help="Export the model with some additional feature.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
|
||||
@@ -127,8 +48,8 @@ def main():
|
||||
|
||||
# Allocate the model
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
model = get_model_from_features(args.features, args.model)
|
||||
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
|
||||
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
|
||||
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
|
||||
onnx_config = model_onnx_config(model.config)
|
||||
|
||||
# Ensure the requested opset is sufficient
|
||||
|
||||
@@ -11,9 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Callable, List, Mapping, Optional
|
||||
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||
|
||||
@@ -26,6 +27,27 @@ DEFAULT_ONNX_OPSET = 11
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PatchingSpec:
|
||||
"""
|
||||
Data class that holds patching specifications.
|
||||
|
||||
Args:
|
||||
o: Module / object where the op to patch is located
|
||||
name: Name of the op to monkey patch
|
||||
custom_op: Custom op that patches the original op
|
||||
orig_op: Original op that is being patched
|
||||
op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
|
||||
It is useful for ops that are class or static methods for instance.
|
||||
"""
|
||||
|
||||
o: Any
|
||||
name: str
|
||||
custom_op: Callable
|
||||
orig_op: Optional[Callable] = None
|
||||
op_wrapper: Optional[Callable] = None
|
||||
|
||||
|
||||
class OnnxConfig(ABC):
|
||||
"""
|
||||
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
|
||||
@@ -34,11 +56,38 @@ class OnnxConfig(ABC):
|
||||
DEFAULT_FIXED_BATCH = 2
|
||||
DEFAULT_FIXED_SEQUENCE = 8
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
_TASKS_TO_COMMON_OUTPUTS = {
|
||||
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
||||
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
||||
"question-answering": OrderedDict(
|
||||
{
|
||||
"start_logits": {0: "batch", 1: "sequence"},
|
||||
"end_logits": {0: "batch", 1: "sequence"},
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
|
||||
self._config = config
|
||||
|
||||
if task not in self._TASKS_TO_COMMON_OUTPUTS:
|
||||
raise ValueError(
|
||||
f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}"
|
||||
)
|
||||
self.task = task
|
||||
|
||||
self._patching_specs = []
|
||||
for spec in patching_specs if patching_specs is not None else []:
|
||||
final_spec = spec
|
||||
if spec.orig_op is None:
|
||||
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
|
||||
self._patching_specs.append(final_spec)
|
||||
|
||||
@classmethod
|
||||
def default(cls, config: PretrainedConfig) -> "OnnxConfig":
|
||||
def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig":
|
||||
"""
|
||||
Instantiate a OnnxConfig for a specific model
|
||||
|
||||
@@ -48,7 +97,7 @@ class OnnxConfig(ABC):
|
||||
Returns:
|
||||
OnnxConfig for this model
|
||||
"""
|
||||
return cls(config)
|
||||
return cls(config, task=task)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -62,7 +111,6 @@ class OnnxConfig(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
"""
|
||||
Mapping containing the axis definition of the output tensors to provide to the model
|
||||
@@ -70,7 +118,7 @@ class OnnxConfig(ABC):
|
||||
Returns:
|
||||
For each output: its name associated to the axes symbolic name and the axis position within the tensor
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
return self._TASKS_TO_COMMON_OUTPUTS[self.task]
|
||||
|
||||
@property
|
||||
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||
@@ -170,14 +218,30 @@ class OnnxConfig(ABC):
|
||||
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||
return dict(tokenizer(dummy_input, return_tensors=framework))
|
||||
|
||||
def patch_ops(self):
|
||||
for spec in self._patching_specs:
|
||||
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
|
||||
setattr(spec.o, spec.name, custom_op)
|
||||
|
||||
def restore_ops(self):
|
||||
for spec in self._patching_specs:
|
||||
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
|
||||
setattr(spec.o, spec.name, orig_op)
|
||||
|
||||
|
||||
class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
||||
super().__init__(config)
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
task: str = "default",
|
||||
patching_specs: List[PatchingSpec] = None,
|
||||
use_past: bool = False,
|
||||
):
|
||||
super().__init__(config, task=task, patching_specs=patching_specs)
|
||||
self.use_past = use_past
|
||||
|
||||
@classmethod
|
||||
def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
|
||||
def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast":
|
||||
"""
|
||||
Instantiate a OnnxConfig with `use_past` attribute set to True
|
||||
|
||||
@@ -187,7 +251,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
Returns:
|
||||
OnnxConfig with `.use_past = True`
|
||||
"""
|
||||
return cls(config, use_past=True)
|
||||
return cls(config, task=task, use_past=True)
|
||||
|
||||
@property
|
||||
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||
|
||||
@@ -111,6 +111,8 @@ def export(
|
||||
if not inputs_match:
|
||||
raise ValueError("Model and config inputs doesn't match")
|
||||
|
||||
config.patch_ops()
|
||||
|
||||
# export can works with named args but the dict containing named args as to be last element of the args tuple
|
||||
export(
|
||||
model,
|
||||
@@ -125,6 +127,8 @@ def export(
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
config.restore_ops()
|
||||
|
||||
return matched_inputs, onnx_outputs
|
||||
|
||||
|
||||
@@ -140,6 +144,8 @@ def validate_model_outputs(
|
||||
|
||||
logger.info("Validating ONNX model...")
|
||||
|
||||
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
||||
# dynamic input shapes.
|
||||
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||
|
||||
# Create ONNX Runtime session
|
||||
@@ -152,6 +158,10 @@ def validate_model_outputs(
|
||||
|
||||
# We flatten potential collection of outputs (i.e. past_keys) to a flat structure
|
||||
for name, value in ref_outputs.items():
|
||||
# Overwriting the output name as "present" since it is the name used for the ONNX ouputs
|
||||
# ("past_key_values" being taken for the ONNX inputs)
|
||||
if name == "past_key_values":
|
||||
name = "present"
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = flatten_output_collection_property(name, value)
|
||||
ref_outputs_dict.update(value)
|
||||
@@ -186,7 +196,7 @@ def validate_model_outputs(
|
||||
|
||||
# Check the shape and values match
|
||||
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
||||
ref_value = ref_outputs_dict[name].numpy()
|
||||
ref_value = ref_outputs_dict[name].detach().numpy()
|
||||
logger.info(f'\t- Validating ONNX Model output "{name}":')
|
||||
|
||||
# Shape
|
||||
@@ -197,7 +207,7 @@ def validate_model_outputs(
|
||||
f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
|
||||
)
|
||||
else:
|
||||
logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}")
|
||||
logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")
|
||||
|
||||
# Values
|
||||
if not np.allclose(ref_value, ort_value, atol=atol):
|
||||
|
||||
135
src/transformers/onnx/features.py
Normal file
135
src/transformers/onnx/features.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from functools import partial, reduce
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from .. import is_torch_available
|
||||
from ..models.albert import AlbertOnnxConfig
|
||||
from ..models.bart import BartOnnxConfig
|
||||
from ..models.bert import BertOnnxConfig
|
||||
from ..models.distilbert import DistilBertOnnxConfig
|
||||
from ..models.gpt2 import GPT2OnnxConfig
|
||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||
from ..models.longformer import LongformerOnnxConfig
|
||||
from ..models.roberta import RobertaOnnxConfig
|
||||
from ..models.t5 import T5OnnxConfig
|
||||
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.models.auto import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
)
|
||||
|
||||
|
||||
def supported_features_mapping(*supported_features, onnx_config_cls=None):
|
||||
"""Generates the mapping between supported features and their corresponding OnnxConfig."""
|
||||
if onnx_config_cls is None:
|
||||
raise ValueError("A OnnxConfig class must be provided")
|
||||
|
||||
mapping = {}
|
||||
for feature in supported_features:
|
||||
if "-with-past" in feature:
|
||||
task = feature.replace("-with-past", "")
|
||||
mapping[feature] = partial(onnx_config_cls.with_past, task=task)
|
||||
else:
|
||||
mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
class FeaturesManager:
|
||||
_TASKS_TO_AUTOMODELS = {
|
||||
"default": AutoModel,
|
||||
"causal-lm": AutoModelForCausalLM,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
"token-classification": AutoModelForTokenClassification,
|
||||
"multiple-choice": AutoModelForMultipleChoice,
|
||||
"question-answering": AutoModelForQuestionAnswering,
|
||||
}
|
||||
|
||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||
_SUPPORTED_MODEL_KIND = {
|
||||
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
|
||||
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
|
||||
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
|
||||
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
|
||||
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
||||
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
||||
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
|
||||
"t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
|
||||
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
|
||||
"gpt-neo": supported_features_mapping(
|
||||
"default",
|
||||
"causal-lm",
|
||||
"sequence-classification",
|
||||
"default-with-past",
|
||||
"causal-lm-with-past",
|
||||
"sequence-classification-with-past",
|
||||
onnx_config_cls=GPTNeoOnnxConfig,
|
||||
),
|
||||
}
|
||||
|
||||
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
|
||||
|
||||
@staticmethod
|
||||
def feature_to_task(feature: str) -> str:
|
||||
return feature.replace("-with-past", "")
|
||||
|
||||
@staticmethod
|
||||
def get_model_from_feature(feature: str, model: str):
|
||||
"""
|
||||
Attempt to retrieve a model from a model's name and the feature to be enabled.
|
||||
|
||||
Args:
|
||||
feature: The feature required
|
||||
model: The name of the model to export
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
task = FeaturesManager.feature_to_task(feature)
|
||||
if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
|
||||
raise KeyError(
|
||||
f"Unknown task: {feature}."
|
||||
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
|
||||
)
|
||||
|
||||
return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)
|
||||
|
||||
@staticmethod
|
||||
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
|
||||
"""
|
||||
Check whether or not the model has the requested features
|
||||
|
||||
Args:
|
||||
model: The model to export
|
||||
feature: The name of the feature to check if it is avaiable
|
||||
|
||||
Returns:
|
||||
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
|
||||
|
||||
"""
|
||||
model_type = model.config.model_type.replace("_", "-")
|
||||
model_name = getattr(model, "name", "")
|
||||
model_name = f"({model_name})" if model_name else ""
|
||||
if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
|
||||
raise KeyError(
|
||||
f"{model.config.model_type} ({model_name}) is not supported yet. "
|
||||
f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. "
|
||||
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
|
||||
)
|
||||
|
||||
# Look for the features
|
||||
model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
|
||||
if feature not in model_features:
|
||||
raise ValueError(
|
||||
f"{model.config.model_type} doesn't support feature {feature}. "
|
||||
f"Supported values are: {list(model_features.keys())}"
|
||||
)
|
||||
|
||||
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature]
|
||||
@@ -9,6 +9,7 @@ from transformers import ( # LongformerConfig,; T5Config,
|
||||
BartConfig,
|
||||
DistilBertConfig,
|
||||
GPT2Config,
|
||||
GPTNeoConfig,
|
||||
RobertaConfig,
|
||||
XLMRobertaConfig,
|
||||
is_torch_available,
|
||||
@@ -20,6 +21,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
|
||||
|
||||
# from transformers.models.longformer import LongformerOnnxConfig
|
||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||
from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
||||
from transformers.models.roberta import RobertaOnnxConfig
|
||||
|
||||
# from transformers.models.t5 import T5OnnxConfig
|
||||
@@ -151,7 +153,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
||||
with self.subTest(name):
|
||||
self.assertFalse(
|
||||
OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
|
||||
OnnxConfigWithPast.from_model_config(config()).use_past,
|
||||
"OnnxConfigWithPast.from_model_config() should not use_past",
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
@@ -167,7 +170,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
with self.subTest(name):
|
||||
|
||||
# without past
|
||||
onnx_config_default = OnnxConfigWithPast.default(config())
|
||||
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
|
||||
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
||||
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
|
||||
self.assertFalse(
|
||||
@@ -190,6 +193,7 @@ if is_torch_available():
|
||||
BertModel,
|
||||
DistilBertModel,
|
||||
GPT2Model,
|
||||
GPTNeoModel,
|
||||
RobertaModel,
|
||||
XLMRobertaModel,
|
||||
)
|
||||
@@ -200,6 +204,7 @@ if is_torch_available():
|
||||
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
|
||||
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
|
||||
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
|
||||
("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
|
||||
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
||||
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
||||
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
||||
|
||||
Reference in New Issue
Block a user