GPT-Neo ONNX export (#12911)
GPT-Neo ONNX export and task / feature refactoring Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_torch_available
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
|
"configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -43,7 +43,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
|
from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_gpt_neo import (
|
from .modeling_gpt_neo import (
|
||||||
|
|||||||
@@ -14,7 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" GPT Neo model configuration """
|
""" GPT Neo model configuration """
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -173,3 +178,140 @@ class GPTNeoConfig(PretrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def num_hidden_layers(self):
|
def num_hidden_layers(self):
|
||||||
return self.num_layers
|
return self.num_layers
|
||||||
|
|
||||||
|
|
||||||
|
def custom_unfold(input, dimension, size, step):
    """
    Custom torch.Tensor.unfold implementation to enable the export to ONNX.

    Builds every window of ``size`` consecutive elements taken along ``dimension`` (window starts are ``step``
    elements apart) with index arithmetic and advanced indexing, then moves the window axis to the last position.

    Args:
        input: Tensor the windows are taken from (``self`` when this function is patched over
            ``torch.Tensor.unfold``).
        dimension: Axis along which the windows are taken. Negative values count from the last axis.
        size: Number of elements in each window.
        step: Distance between the starts of two consecutive windows.

    Returns:
        A tensor with one more axis than ``input``: axis ``dimension`` enumerates the windows and the new last axis
        holds the ``size`` elements of each window.
    """
    import torch

    shape = input.size()
    rank = len(shape)
    # ``dimension + 1`` is used below as a position in the permutation list, so a negative axis has to be
    # normalised first (``-1`` would otherwise move axis 0 instead of the window axis).
    if dimension < 0:
        dimension += rank
    sizedim = shape[dimension]

    low_indices = torch.arange(0, sizedim, step)
    # ``torch.div`` requires a tensor as its first argument; ``sizedim - size`` may be a plain Python int when the
    # function is called outside of tracing, so wrap it (a no-op for values that are already tensors).
    min_length = torch.div(torch.as_tensor(sizedim - size), step, rounding_mode="floor") + 1
    indices = torch.arange(size) + low_indices[:min_length][:, None]

    s = [slice(None)] * rank
    s[dimension] = indices
    # Multidimensional indexing expects a tuple; indexing with a list of slices is deprecated.
    sliced = input[tuple(s)]

    # The advanced index inserted the window axis right after ``dimension``: move it to the end.
    perm = list(range(0, rank + 1))
    perm.append(perm.pop(dimension + 1))

    return sliced.permute(perm)
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_block_length_and_num_blocks(seq_length, window_size):
    """
    Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
    original implementation uses Python variables and control flow.

    Args:
        seq_length: Sequence length, either a Python int or an integer tensor.
        window_size: Upper bound (inclusive) for the block length.

    Returns:
        A tuple ``(block_length, num_blocks)`` of tensors where ``block_length`` is the largest divisor of
        ``seq_length`` that does not exceed ``window_size`` and ``num_blocks`` is ``seq_length // block_length``.
    """
    import torch

    # ``torch.div`` below needs a tensor as first argument; this is a no-op when a tensor is passed in.
    seq_length = torch.as_tensor(seq_length)

    # ``torch.arange`` excludes its end point: go up to ``window_size + 1`` so that ``window_size`` itself is a
    # candidate (otherwise an exact multiple of the window gets a smaller block, and ``window_size == 1`` leaves no
    # candidate at all for ``torch.max``).
    candidates = torch.arange(1, window_size + 1)
    remainders = torch.remainder(seq_length, candidates)
    divisor_indices = remainders == 0
    divisors = candidates[divisor_indices]
    largest_divisor = torch.max(divisors)
    return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor")
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoOnnxConfig(OnnxConfigWithPast):
    """
    ONNX export configuration for GPT Neo.

    Declares the dynamic axes of the model inputs / outputs, the ops to monkey patch during the export and the dummy
    inputs used to trace the model. A "local" attention layer contributes a single cached tensor while any other
    layer contributes two, which is why the cached tensors are counted and described per layer.
    """

    def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False):
        # Without PyTorch there is nothing to patch. ``None`` is the parent's default and means "no patching";
        # binding it here keeps the name defined on every path.
        patching_specs = None
        if is_torch_available():
            import torch

            from .modeling_gpt_neo import GPTNeoAttentionMixin

            patching_specs = [
                PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold),
                PatchingSpec(
                    GPTNeoAttentionMixin,
                    name="_get_block_length_and_num_blocks",
                    custom_op=custom_get_block_length_and_num_blocks,
                    op_wrapper=staticmethod,
                ),
            ]

        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)

        self._num_local_attention = sum(1 for type_ in self._config.attention_layers if type_ == "local")

        # One dynamic-axes entry per cached tensor, in layer order: one entry for a local layer (sequence on
        # axis 1), two entries for any other layer (sequence on axis 2).
        self._key_values_dynamic_axis = []
        for i in range(self._config.num_layers):
            if self._config.attention_layers[i] == "local":
                self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"})
            else:
                self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})
                self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"})

    @property
    def _number_key_values(self):
        """Total number of cached tensors: two per layer, minus one for every local attention layer."""
        return (self._config.num_layers * 2) - self._num_local_attention

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Dynamic axes of the model inputs, in order: ``input_ids``, then one ``past_key_values.{i}`` entry per cached
        tensor when ``use_past`` is set, then ``attention_mask``.
        """
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            for i in range(self._number_key_values):
                common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]

        common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Dynamic axes of the model outputs: the outputs of the task, plus one ``present.{i}`` entry per cached tensor
        when ``use_past`` is set.
        """
        # Work on a copy: the mapping handed back by the parent may be an object shared between configs, and
        # adding the ``present.*`` keys to it in place would leak them into every other user of that mapping.
        common_outputs = OrderedDict(super().outputs)
        if self.use_past:
            for i in range(self._number_key_values):
                common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]

        return common_outputs

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Generate the dummy inputs used for the export, ordered the way they appear in the model ``forward()``.

        Args:
            tokenizer: Tokenizer forwarded to the parent implementation.
            batch_size: Forwarded to the parent implementation.
            seq_length: Forwarded to the parent implementation.
            is_pair: Forwarded to the parent implementation.
            framework: Forwarded to the parent implementation.

        Returns:
            An ordered mapping with ``input_ids``, then ``past_key_values`` (only when ``use_past`` is set) and
            finally ``attention_mask``.

        Raises:
            ValueError: If ``use_past`` is set and PyTorch is not installed.
        """
        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # We need to order the inputs in the way they appear in the forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
        attention_mask = common_inputs["attention_mask"]

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")

            import torch

            # Only read ``.shape`` here: the batch size is needed for the cached tensors alone, and the inputs
            # are not guaranteed to expose ``.shape`` when no past is requested.
            batch = common_inputs["input_ids"].shape[0]
            past_shapes = {
                "global": (
                    batch,
                    self._config.num_heads,
                    1,
                    self._config.hidden_size // self._config.num_attention_heads,
                ),
                "local": (batch, 1, self._config.hidden_size),
            }

            past_key_values = []
            for i in range(self._config.num_layers):
                attention_type = self._config.attention_layers[i]
                if attention_type == "global":
                    # Two cached tensors for a global layer.
                    past_key_values.append(
                        (
                            torch.zeros(past_shapes[attention_type]),
                            torch.zeros(past_shapes[attention_type]),
                        )
                    )
                else:
                    # A single cached tensor otherwise.
                    past_key_values.append((torch.zeros(past_shapes[attention_type]),))
            ordered_inputs["past_key_values"] = past_key_values

            # The mask gets one extra position for the cached step.
            attention_mask = torch.cat([attention_mask, torch.zeros(batch, 1)], dim=1)

        ordered_inputs["attention_mask"] = attention_mask

        return ordered_inputs
|
||||||
|
|||||||
@@ -1121,7 +1121,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
|||||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -13,6 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
|
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
|
||||||
from .convert import export, validate_model_outputs
|
from .convert import export, validate_model_outputs
|
||||||
from .utils import ParameterFormat, compute_serialized_parameters_size
|
from .utils import ParameterFormat, compute_serialized_parameters_size
|
||||||
|
|||||||
@@ -14,101 +14,22 @@
|
|||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Tuple
|
|
||||||
|
|
||||||
from transformers.models.albert import AlbertOnnxConfig
|
|
||||||
from transformers.models.auto import AutoTokenizer
|
from transformers.models.auto import AutoTokenizer
|
||||||
from transformers.models.bart import BartOnnxConfig
|
|
||||||
from transformers.models.bert import BertOnnxConfig
|
|
||||||
from transformers.models.distilbert import DistilBertOnnxConfig
|
|
||||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
|
||||||
from transformers.models.longformer import LongformerOnnxConfig
|
|
||||||
from transformers.models.roberta import RobertaOnnxConfig
|
|
||||||
from transformers.models.t5 import T5OnnxConfig
|
|
||||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
|
||||||
|
|
||||||
from .. import is_torch_available
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .convert import export, validate_model_outputs
|
from .convert import export, validate_model_outputs
|
||||||
|
from .features import FeaturesManager
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
from transformers import AutoModel, PreTrainedModel
|
|
||||||
|
|
||||||
FEATURES_TO_AUTOMODELS = {
|
|
||||||
"default": AutoModel,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
|
||||||
SUPPORTED_MODEL_KIND = {
|
|
||||||
"albert": {"default": AlbertOnnxConfig.default},
|
|
||||||
"bart": {"default": BartOnnxConfig.default},
|
|
||||||
"bert": {"default": BertOnnxConfig.default},
|
|
||||||
"distilbert": {"default": DistilBertOnnxConfig.default},
|
|
||||||
"gpt2": {"default": GPT2OnnxConfig.default},
|
|
||||||
"longformer": {"default": LongformerOnnxConfig.default},
|
|
||||||
"roberta": {"default": RobertaOnnxConfig},
|
|
||||||
"t5": {"default": T5OnnxConfig.default},
|
|
||||||
"xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_from_features(features: str, model: str):
|
|
||||||
"""
|
|
||||||
Attempt to retrieve a model from a model's name and the features to be enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features: The features required
|
|
||||||
model: The name of the model to export
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
|
||||||
if features not in FEATURES_TO_AUTOMODELS:
|
|
||||||
raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}")
|
|
||||||
|
|
||||||
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
|
|
||||||
|
|
||||||
|
|
||||||
def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
|
|
||||||
"""
|
|
||||||
Check whether or not the model has the requested features
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model to export
|
|
||||||
features: The name of the features to check if they are avaiable
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties
|
|
||||||
|
|
||||||
"""
|
|
||||||
if model.config.model_type not in SUPPORTED_MODEL_KIND:
|
|
||||||
raise KeyError(
|
|
||||||
f"{model.config.model_type} ({model.name}) is not supported yet. "
|
|
||||||
f"Only {SUPPORTED_MODEL_KIND} are supported. "
|
|
||||||
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Look for the features
|
|
||||||
model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
|
|
||||||
if features not in model_features:
|
|
||||||
raise ValueError(
|
|
||||||
f"{model.config.model_type} doesn't support features {features}. "
|
|
||||||
f"Supported values are: {list(model_features.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
|
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
|
||||||
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
|
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--features",
|
"--feature",
|
||||||
choices=["default"],
|
choices=list(FeaturesManager.AVAILABLE_FEATURES),
|
||||||
default="default",
|
default="default",
|
||||||
help="Export the model with some additional features.",
|
help="Export the model with some additional feature.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
|
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
|
||||||
@@ -127,8 +48,8 @@ def main():
|
|||||||
|
|
||||||
# Allocate the model
|
# Allocate the model
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||||
model = get_model_from_features(args.features, args.model)
|
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
|
||||||
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
|
||||||
onnx_config = model_onnx_config(model.config)
|
onnx_config = model_onnx_config(model.config)
|
||||||
|
|
||||||
# Ensure the requested opset is sufficient
|
# Ensure the requested opset is sufficient
|
||||||
|
|||||||
@@ -11,9 +11,10 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Mapping, Optional
|
from typing import Any, Callable, List, Mapping, Optional
|
||||||
|
|
||||||
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||||
|
|
||||||
@@ -26,6 +27,27 @@ DEFAULT_ONNX_OPSET = 11
|
|||||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
|
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class PatchingSpec:
    """
    Data class that holds patching specifications.

    A spec names one attribute (``name``) of one object (``o``) together with the replacement to install on it and
    the value to put back afterwards.

    Args:
        o: Module / object where the op to patch is located
        name: Name of the op to monkey patch
        custom_op: Custom op that patches the original op
        orig_op: Original op that is being patched. When left to ``None``, ``OnnxConfig.__init__`` fills it in with
            ``getattr(o, name)``.
        op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
            It is useful for ops that are class or static methods for instance.
    """

    o: Any
    name: str
    custom_op: Callable
    orig_op: Optional[Callable] = None
    op_wrapper: Optional[Callable] = None
|
||||||
|
|
||||||
|
|
||||||
class OnnxConfig(ABC):
|
class OnnxConfig(ABC):
|
||||||
"""
|
"""
|
||||||
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
|
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
|
||||||
@@ -34,11 +56,38 @@ class OnnxConfig(ABC):
|
|||||||
DEFAULT_FIXED_BATCH = 2
|
DEFAULT_FIXED_BATCH = 2
|
||||||
DEFAULT_FIXED_SEQUENCE = 8
|
DEFAULT_FIXED_SEQUENCE = 8
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig):
|
_TASKS_TO_COMMON_OUTPUTS = {
|
||||||
|
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
||||||
|
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
|
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||||
|
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
|
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
||||||
|
"question-answering": OrderedDict(
|
||||||
|
{
|
||||||
|
"start_logits": {0: "batch", 1: "sequence"},
|
||||||
|
"end_logits": {0: "batch", 1: "sequence"},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
|
if task not in self._TASKS_TO_COMMON_OUTPUTS:
|
||||||
|
raise ValueError(
|
||||||
|
f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}"
|
||||||
|
)
|
||||||
|
self.task = task
|
||||||
|
|
||||||
|
self._patching_specs = []
|
||||||
|
for spec in patching_specs if patching_specs is not None else []:
|
||||||
|
final_spec = spec
|
||||||
|
if spec.orig_op is None:
|
||||||
|
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
|
||||||
|
self._patching_specs.append(final_spec)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls, config: PretrainedConfig) -> "OnnxConfig":
|
def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig":
|
||||||
"""
|
"""
|
||||||
Instantiate a OnnxConfig for a specific model
|
Instantiate a OnnxConfig for a specific model
|
||||||
|
|
||||||
@@ -48,7 +97,7 @@ class OnnxConfig(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
OnnxConfig for this model
|
OnnxConfig for this model
|
||||||
"""
|
"""
|
||||||
return cls(config)
|
return cls(config, task=task)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -62,7 +111,6 @@ class OnnxConfig(ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
|
||||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
"""
|
"""
|
||||||
Mapping containing the axis definition of the output tensors to provide to the model
|
Mapping containing the axis definition of the output tensors to provide to the model
|
||||||
@@ -70,7 +118,7 @@ class OnnxConfig(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
For each output: its name associated to the axes symbolic name and the axis position within the tensor
|
For each output: its name associated to the axes symbolic name and the axis position within the tensor
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
return self._TASKS_TO_COMMON_OUTPUTS[self.task]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def values_override(self) -> Optional[Mapping[str, Any]]:
|
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||||
@@ -170,14 +218,30 @@ class OnnxConfig(ABC):
|
|||||||
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||||
return dict(tokenizer(dummy_input, return_tensors=framework))
|
return dict(tokenizer(dummy_input, return_tensors=framework))
|
||||||
|
|
||||||
|
def patch_ops(self):
|
||||||
|
for spec in self._patching_specs:
|
||||||
|
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
|
||||||
|
setattr(spec.o, spec.name, custom_op)
|
||||||
|
|
||||||
|
def restore_ops(self):
|
||||||
|
for spec in self._patching_specs:
|
||||||
|
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
|
||||||
|
setattr(spec.o, spec.name, orig_op)
|
||||||
|
|
||||||
|
|
||||||
class OnnxConfigWithPast(OnnxConfig, ABC):
|
class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||||
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
def __init__(
|
||||||
super().__init__(config)
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
task: str = "default",
|
||||||
|
patching_specs: List[PatchingSpec] = None,
|
||||||
|
use_past: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(config, task=task, patching_specs=patching_specs)
|
||||||
self.use_past = use_past
|
self.use_past = use_past
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
|
def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast":
|
||||||
"""
|
"""
|
||||||
Instantiate a OnnxConfig with `use_past` attribute set to True
|
Instantiate a OnnxConfig with `use_past` attribute set to True
|
||||||
|
|
||||||
@@ -187,7 +251,7 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
OnnxConfig with `.use_past = True`
|
OnnxConfig with `.use_past = True`
|
||||||
"""
|
"""
|
||||||
return cls(config, use_past=True)
|
return cls(config, task=task, use_past=True)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def values_override(self) -> Optional[Mapping[str, Any]]:
|
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||||
|
|||||||
@@ -111,6 +111,8 @@ def export(
|
|||||||
if not inputs_match:
|
if not inputs_match:
|
||||||
raise ValueError("Model and config inputs doesn't match")
|
raise ValueError("Model and config inputs doesn't match")
|
||||||
|
|
||||||
|
config.patch_ops()
|
||||||
|
|
||||||
# export can work with named args, but the dict containing the named args has to be the last element of the args tuple
|
# export can work with named args, but the dict containing the named args has to be the last element of the args tuple
|
||||||
export(
|
export(
|
||||||
model,
|
model,
|
||||||
@@ -125,6 +127,8 @@ def export(
|
|||||||
opset_version=opset,
|
opset_version=opset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
config.restore_ops()
|
||||||
|
|
||||||
return matched_inputs, onnx_outputs
|
return matched_inputs, onnx_outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -140,6 +144,8 @@ def validate_model_outputs(
|
|||||||
|
|
||||||
logger.info("Validating ONNX model...")
|
logger.info("Validating ONNX model...")
|
||||||
|
|
||||||
|
# TODO: generate inputs with a different batch_size and seq_len than was used for conversion to properly test
|
||||||
|
# dynamic input shapes.
|
||||||
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||||
|
|
||||||
# Create ONNX Runtime session
|
# Create ONNX Runtime session
|
||||||
@@ -152,6 +158,10 @@ def validate_model_outputs(
|
|||||||
|
|
||||||
# We flatten potential collection of outputs (i.e. past_keys) to a flat structure
|
# We flatten potential collection of outputs (i.e. past_keys) to a flat structure
|
||||||
for name, value in ref_outputs.items():
|
for name, value in ref_outputs.items():
|
||||||
|
# Overwriting the output name as "present" since it is the name used for the ONNX outputs
|
||||||
|
# ("past_key_values" being taken for the ONNX inputs)
|
||||||
|
if name == "past_key_values":
|
||||||
|
name = "present"
|
||||||
if isinstance(value, (list, tuple)):
|
if isinstance(value, (list, tuple)):
|
||||||
value = flatten_output_collection_property(name, value)
|
value = flatten_output_collection_property(name, value)
|
||||||
ref_outputs_dict.update(value)
|
ref_outputs_dict.update(value)
|
||||||
@@ -186,7 +196,7 @@ def validate_model_outputs(
|
|||||||
|
|
||||||
# Check the shape and values match
|
# Check the shape and values match
|
||||||
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
||||||
ref_value = ref_outputs_dict[name].numpy()
|
ref_value = ref_outputs_dict[name].detach().numpy()
|
||||||
logger.info(f'\t- Validating ONNX Model output "{name}":')
|
logger.info(f'\t- Validating ONNX Model output "{name}":')
|
||||||
|
|
||||||
# Shape
|
# Shape
|
||||||
@@ -197,7 +207,7 @@ def validate_model_outputs(
|
|||||||
f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
|
f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}")
|
logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")
|
||||||
|
|
||||||
# Values
|
# Values
|
||||||
if not np.allclose(ref_value, ort_value, atol=atol):
|
if not np.allclose(ref_value, ort_value, atol=atol):
|
||||||
|
|||||||
135
src/transformers/onnx/features.py
Normal file
135
src/transformers/onnx/features.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
from functools import partial, reduce
|
||||||
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
|
from .. import is_torch_available
|
||||||
|
from ..models.albert import AlbertOnnxConfig
|
||||||
|
from ..models.bart import BartOnnxConfig
|
||||||
|
from ..models.bert import BertOnnxConfig
|
||||||
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
|
from ..models.longformer import LongformerOnnxConfig
|
||||||
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
|
from ..models.t5 import T5OnnxConfig
|
||||||
|
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from transformers.models.auto import (
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForMultipleChoice,
|
||||||
|
AutoModelForQuestionAnswering,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForTokenClassification,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def supported_features_mapping(*supported_features, onnx_config_cls=None):
    """
    Build the mapping from each supported feature name to the factory creating its OnnxConfig.

    Args:
        *supported_features: Names of the supported features. A name containing ``-with-past`` is mapped to
            ``onnx_config_cls.with_past``, any other name to ``onnx_config_cls.from_model_config``.
        onnx_config_cls: Class providing the ``with_past`` / ``from_model_config`` constructors.

    Returns:
        A dict mapping every feature name to a ``functools.partial`` with ``task`` already bound (the feature name
        stripped of ``-with-past``).

    Raises:
        ValueError: If ``onnx_config_cls`` is not provided.
    """
    if onnx_config_cls is None:
        raise ValueError("A OnnxConfig class must be provided")

    past_marker = "-with-past"

    def _factory_for(name):
        # Features flagged with the marker go through the ``with_past`` constructor, for the bare task name.
        if past_marker in name:
            return partial(onnx_config_cls.with_past, task=name.replace(past_marker, ""))
        return partial(onnx_config_cls.from_model_config, task=name)

    return {name: _factory_for(name) for name in supported_features}
|
||||||
|
|
||||||
|
|
||||||
|
class FeaturesManager:
    """
    Registry of the features that can be exported to ONNX.

    Maps tasks to the auto class used to load the model, and model types to the features they support together with
    the factory creating the matching OnnxConfig.
    """

    # Task name -> auto class used to instantiate the model to export.
    _TASKS_TO_AUTOMODELS = {
        "default": AutoModel,
        "causal-lm": AutoModelForCausalLM,
        "sequence-classification": AutoModelForSequenceClassification,
        "token-classification": AutoModelForTokenClassification,
        "multiple-choice": AutoModelForMultipleChoice,
        "question-answering": AutoModelForQuestionAnswering,
    }

    # Set of model topologies we support associated to the features supported by each topology and the factory
    _SUPPORTED_MODEL_KIND = {
        "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
        "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
        "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
        "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
        "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
        "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
        "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
        "t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
        "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
        "gpt-neo": supported_features_mapping(
            "default",
            "causal-lm",
            "sequence-classification",
            "default-with-past",
            "causal-lm-with-past",
            "sequence-classification-with-past",
            onnx_config_cls=GPTNeoOnnxConfig,
        ),
    }

    # Sorted union of the feature names supported by at least one model type.
    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))

    @staticmethod
    def feature_to_task(feature: str) -> str:
        """Return the task a feature refers to, i.e. the feature name without ``-with-past``."""
        return feature.replace("-with-past", "")

    @staticmethod
    def get_model_from_feature(feature: str, model: str):
        """
        Attempt to retrieve a model from a model's name and the feature to be enabled.

        Args:
            feature: The feature required
            model: The name of the model to export

        Returns:
            The model loaded with ``from_pretrained`` through the auto class registered for the task of the feature

        Raises:
            KeyError: If no auto class is registered for the task of the feature
        """
        task = FeaturesManager.feature_to_task(feature)
        if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
            # List the task names (the keys), not the auto classes, and keep the two sentences apart.
            raise KeyError(
                f"Unknown task: {task} (requested feature: {feature}). "
                f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.keys())}"
            )

        return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model)

    @staticmethod
    def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
        """
        Check whether or not the model has the requested features

        Args:
            model: The model to export
            feature: The name of the feature to check if it is available

        Returns:
            (str) The type of the model (Callable) The factory creating, from the model config, the OnnxConfig
            holding the model export properties

        Raises:
            KeyError: If the model type is not supported
            ValueError: If the model type does not support the requested feature
        """
        # The registry uses dashes where ``model_type`` uses underscores (e.g. "gpt_neo" -> "gpt-neo").
        model_type = model.config.model_type.replace("_", "-")
        model_name = getattr(model, "name", "")
        # Parenthesised once, here, and only when the model actually has a name.
        model_name = f" ({model_name})" if model_name else ""
        if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND:
            raise KeyError(
                f"{model.config.model_type}{model_name} is not supported yet. "
                f"Only {list(FeaturesManager._SUPPORTED_MODEL_KIND.keys())} are supported. "
                f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
            )

        # Look for the features
        model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type]
        if feature not in model_features:
            raise ValueError(
                f"{model.config.model_type} doesn't support feature {feature}. "
                f"Supported values are: {list(model_features.keys())}"
            )

        return model.config.model_type, model_features[feature]
|
||||||
@@ -9,6 +9,7 @@ from transformers import ( # LongformerConfig,; T5Config,
|
|||||||
BartConfig,
|
BartConfig,
|
||||||
DistilBertConfig,
|
DistilBertConfig,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
|
GPTNeoConfig,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
XLMRobertaConfig,
|
XLMRobertaConfig,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -20,6 +21,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
|
|||||||
|
|
||||||
# from transformers.models.longformer import LongformerOnnxConfig
|
# from transformers.models.longformer import LongformerOnnxConfig
|
||||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||||
|
from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
||||||
from transformers.models.roberta import RobertaOnnxConfig
|
from transformers.models.roberta import RobertaOnnxConfig
|
||||||
|
|
||||||
# from transformers.models.t5 import T5OnnxConfig
|
# from transformers.models.t5 import T5OnnxConfig
|
||||||
@@ -151,7 +153,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||||||
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
||||||
with self.subTest(name):
|
with self.subTest(name):
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
|
OnnxConfigWithPast.from_model_config(config()).use_past,
|
||||||
|
"OnnxConfigWithPast.from_model_config() should not use_past",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@@ -167,7 +170,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||||||
with self.subTest(name):
|
with self.subTest(name):
|
||||||
|
|
||||||
# without past
|
# without past
|
||||||
onnx_config_default = OnnxConfigWithPast.default(config())
|
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
|
||||||
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
||||||
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
|
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
@@ -190,6 +193,7 @@ if is_torch_available():
|
|||||||
BertModel,
|
BertModel,
|
||||||
DistilBertModel,
|
DistilBertModel,
|
||||||
GPT2Model,
|
GPT2Model,
|
||||||
|
GPTNeoModel,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
XLMRobertaModel,
|
XLMRobertaModel,
|
||||||
)
|
)
|
||||||
@@ -200,6 +204,7 @@ if is_torch_available():
|
|||||||
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
|
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
|
||||||
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
|
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
|
||||||
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
|
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
|
||||||
|
("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
|
||||||
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
||||||
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
||||||
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
||||||
|
|||||||
Reference in New Issue
Block a user