T5 with past ONNX export (#13014)
Add support for exporting T5 to ONNX with past key/values (the "-with-past" features), and use more explicit names for the past_key_values inputs and outputs of the exported ONNX model. Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -15,7 +15,7 @@
|
|||||||
""" GPT Neo model configuration """
|
""" GPT Neo model configuration """
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Mapping, Optional
|
from typing import Any, Dict, Iterable, Mapping, Optional
|
||||||
|
|
||||||
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
@@ -253,8 +253,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
|||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||||
if self.use_past:
|
if self.use_past:
|
||||||
for i in range(self._number_key_values):
|
for i in range(self._config.num_layers):
|
||||||
common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
|
if self._config.attention_layers[i] == "local":
|
||||||
|
common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
|
||||||
|
else:
|
||||||
|
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
|
||||||
|
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
|
||||||
|
|
||||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||||
|
|
||||||
@@ -264,9 +268,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
|||||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
common_outputs = super().outputs
|
common_outputs = super().outputs
|
||||||
if self.use_past:
|
if self.use_past:
|
||||||
for i in range(self._number_key_values):
|
for i in range(self._config.num_layers):
|
||||||
common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
|
if self._config.attention_layers[i] == "local":
|
||||||
|
common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
|
||||||
|
else:
|
||||||
|
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
|
||||||
|
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
|
||||||
return common_outputs
|
return common_outputs
|
||||||
|
|
||||||
def generate_dummy_inputs(
|
def generate_dummy_inputs(
|
||||||
@@ -315,3 +322,18 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ordered_inputs
|
return ordered_inputs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||||
|
if name in ["present", "past_key_values"]:
|
||||||
|
flatten_output = {}
|
||||||
|
for idx, t in enumerate(field):
|
||||||
|
if len(t) == 1:
|
||||||
|
flatten_output[f"{name}.{idx}.key_value"] = t[0]
|
||||||
|
else:
|
||||||
|
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||||
|
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||||
|
|
||||||
|
return flatten_output
|
||||||
|
|
||||||
|
return super().flatten_output_collection_property(name, field)
|
||||||
|
|||||||
@@ -14,10 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" T5 model configuration """
|
""" T5 model configuration """
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Mapping, Optional
|
from typing import Any, Dict, Iterable, Mapping, Optional
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, TensorType
|
from transformers import PreTrainedTokenizer, TensorType
|
||||||
|
|
||||||
|
from ... import is_torch_available
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...onnx import OnnxConfigWithPast
|
from ...onnx import OnnxConfigWithPast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -140,9 +141,6 @@ class T5Config(PretrainedConfig):
|
|||||||
|
|
||||||
|
|
||||||
class T5OnnxConfig(OnnxConfigWithPast):
|
class T5OnnxConfig(OnnxConfigWithPast):
|
||||||
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
|
||||||
super().__init__(config, use_past)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
common_inputs = OrderedDict(
|
common_inputs = OrderedDict(
|
||||||
@@ -155,29 +153,30 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.use_past:
|
if self.use_past:
|
||||||
for i in range(self._config.num_layers):
|
for i in range(0, self._config.num_layers):
|
||||||
common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
|
common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||||
common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
|
common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||||
common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
|
common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||||
common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
|
common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||||
|
|
||||||
return common_inputs
|
return common_inputs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
common_outputs = OrderedDict(
|
common_outputs = super().outputs
|
||||||
[
|
|
||||||
("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
|
if "last_hidden_state" in common_outputs:
|
||||||
("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
|
common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_past:
|
if self.use_past:
|
||||||
for i in range(self._config.num_layers):
|
for i in range(self._config.num_layers):
|
||||||
common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
|
common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
|
||||||
common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
|
common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
|
||||||
common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
|
common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
|
||||||
common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
|
common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
|
||||||
|
|
||||||
|
if self.task == "default":
|
||||||
|
common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
|
||||||
|
|
||||||
return common_outputs
|
return common_outputs
|
||||||
|
|
||||||
@@ -189,8 +188,6 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
|||||||
is_pair: bool = False,
|
is_pair: bool = False,
|
||||||
framework: Optional[TensorType] = None,
|
framework: Optional[TensorType] = None,
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
if self.use_past:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
# Generate encoder inputs
|
# Generate encoder inputs
|
||||||
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||||
@@ -199,4 +196,45 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
|||||||
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
|
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
|
||||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||||
|
|
||||||
return dict(**encoder_inputs, **decoder_inputs)
|
ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||||
|
if self.use_past:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||||
|
else:
|
||||||
|
import torch
|
||||||
|
batch = encoder_inputs["input_ids"].shape[0]
|
||||||
|
encoder_seq_length = encoder_inputs["input_ids"].shape[1]
|
||||||
|
encoder_shape = (
|
||||||
|
batch,
|
||||||
|
self._config.num_heads,
|
||||||
|
encoder_seq_length,
|
||||||
|
self._config.hidden_size // self._config.num_heads,
|
||||||
|
)
|
||||||
|
decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
|
||||||
|
|
||||||
|
ordered_inputs["past_key_values"] = []
|
||||||
|
for _ in range(self._config.num_layers):
|
||||||
|
ordered_inputs["past_key_values"].append(
|
||||||
|
(
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(decoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
torch.zeros(encoder_shape),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ordered_inputs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||||
|
if name in ["present", "past_key_values"]:
|
||||||
|
flatten_output = {}
|
||||||
|
for idx, t in enumerate(field):
|
||||||
|
flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
|
||||||
|
flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
|
||||||
|
flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
|
||||||
|
flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
|
||||||
|
|
||||||
|
return flatten_output
|
||||||
|
|
||||||
|
return super().flatten_output_collection_property(name, field)
|
||||||
|
|||||||
@@ -429,8 +429,6 @@ class T5Attention(nn.Module):
|
|||||||
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
|
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
|
||||||
batch_size, seq_length = hidden_states.shape[:2]
|
batch_size, seq_length = hidden_states.shape[:2]
|
||||||
|
|
||||||
int_seq_length = int(seq_length)
|
|
||||||
|
|
||||||
real_seq_length = seq_length
|
real_seq_length = seq_length
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
@@ -499,7 +497,7 @@ class T5Attention(nn.Module):
|
|||||||
# if key and values are already calculated
|
# if key and values are already calculated
|
||||||
# we want only the last query position bias
|
# we want only the last query position bias
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
position_bias = position_bias[:, :, -int_seq_length:, :]
|
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
||||||
@@ -629,7 +627,7 @@ class T5Block(nn.Module):
|
|||||||
if len(past_key_value) != expected_num_past_key_values:
|
if len(past_key_value) != expected_num_past_key_values:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"There should be {expected_num_past_key_values} past states. "
|
f"There should be {expected_num_past_key_values} past states. "
|
||||||
f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
|
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
|
||||||
f"Got {len(past_key_value)} past key / value states"
|
f"Got {len(past_key_value)} past key / value states"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Callable, List, Mapping, Optional
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
|
||||||
|
|
||||||
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||||
|
|
||||||
@@ -59,6 +59,7 @@ class OnnxConfig(ABC):
|
|||||||
_TASKS_TO_COMMON_OUTPUTS = {
|
_TASKS_TO_COMMON_OUTPUTS = {
|
||||||
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
||||||
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
|
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
|
||||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
||||||
@@ -228,6 +229,24 @@ class OnnxConfig(ABC):
|
|||||||
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
|
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
|
||||||
setattr(spec.o, spec.name, orig_op)
|
setattr(spec.o, spec.name, orig_op)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Flatten any potential nested structure expanding the name of the field with the index of the element within the
|
||||||
|
structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the nested structure
|
||||||
|
field: The structure to, potentially, be flattened
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
|
||||||
|
|
||||||
|
|
||||||
class OnnxConfigWithPast(OnnxConfig, ABC):
|
class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -285,3 +304,15 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
|||||||
# Generate dummy inputs according to compute batch and sequence
|
# Generate dummy inputs according to compute batch and sequence
|
||||||
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||||
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
|
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||||
|
if name in ["present", "past_key_values"]:
|
||||||
|
flatten_output = {}
|
||||||
|
for idx, t in enumerate(field):
|
||||||
|
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||||
|
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||||
|
|
||||||
|
return flatten_output
|
||||||
|
|
||||||
|
return super().flatten_output_collection_property(name, field)
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedMod
|
|||||||
from ..file_utils import is_torch_onnx_dict_inputs_support_available
|
from ..file_utils import is_torch_onnx_dict_inputs_support_available
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .config import OnnxConfig
|
from .config import OnnxConfig
|
||||||
from .utils import flatten_output_collection_property
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -163,7 +162,7 @@ def validate_model_outputs(
|
|||||||
if name == "past_key_values":
|
if name == "past_key_values":
|
||||||
name = "present"
|
name = "present"
|
||||||
if isinstance(value, (list, tuple)):
|
if isinstance(value, (list, tuple)):
|
||||||
value = flatten_output_collection_property(name, value)
|
value = config.flatten_output_collection_property(name, value)
|
||||||
ref_outputs_dict.update(value)
|
ref_outputs_dict.update(value)
|
||||||
else:
|
else:
|
||||||
ref_outputs_dict[name] = value
|
ref_outputs_dict[name] = value
|
||||||
@@ -172,7 +171,7 @@ def validate_model_outputs(
|
|||||||
onnx_inputs = {}
|
onnx_inputs = {}
|
||||||
for name, value in reference_model_inputs.items():
|
for name, value in reference_model_inputs.items():
|
||||||
if isinstance(value, (list, tuple)):
|
if isinstance(value, (list, tuple)):
|
||||||
value = flatten_output_collection_property(name, value)
|
value = config.flatten_output_collection_property(name, value)
|
||||||
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
|
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
|
||||||
else:
|
else:
|
||||||
onnx_inputs[name] = value.numpy()
|
onnx_inputs[name] = value.numpy()
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ if is_torch_available():
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
)
|
)
|
||||||
@@ -46,6 +47,7 @@ class FeaturesManager:
|
|||||||
_TASKS_TO_AUTOMODELS = {
|
_TASKS_TO_AUTOMODELS = {
|
||||||
"default": AutoModel,
|
"default": AutoModel,
|
||||||
"causal-lm": AutoModelForCausalLM,
|
"causal-lm": AutoModelForCausalLM,
|
||||||
|
"seq2seq-lm": AutoModelForSeq2SeqLM,
|
||||||
"sequence-classification": AutoModelForSequenceClassification,
|
"sequence-classification": AutoModelForSequenceClassification,
|
||||||
"token-classification": AutoModelForTokenClassification,
|
"token-classification": AutoModelForTokenClassification,
|
||||||
"multiple-choice": AutoModelForMultipleChoice,
|
"multiple-choice": AutoModelForMultipleChoice,
|
||||||
@@ -61,7 +63,9 @@ class FeaturesManager:
|
|||||||
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
||||||
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
||||||
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
|
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
|
||||||
"t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
|
"t5": supported_features_mapping(
|
||||||
|
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
|
||||||
|
),
|
||||||
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
|
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
|
||||||
"gpt-neo": supported_features_mapping(
|
"gpt-neo": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
from ctypes import c_float, sizeof
|
from ctypes import c_float, sizeof
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Iterable
|
|
||||||
|
|
||||||
|
|
||||||
class ParameterFormat(Enum):
|
class ParameterFormat(Enum):
|
||||||
@@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
|
|||||||
Size (in byte) taken to save all the parameters
|
Size (in byte) taken to save all the parameters
|
||||||
"""
|
"""
|
||||||
return num_parameters * dtype.size
|
return num_parameters * dtype.size
|
||||||
|
|
||||||
|
|
||||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Flatten any potential nested structure expanding the name of the field with the index of the element within the
|
|
||||||
structure.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The name of the nested structure
|
|
||||||
field: The structure to, potentially, be flattened
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
|
|
||||||
|
|
||||||
"""
|
|
||||||
from itertools import chain
|
|
||||||
|
|
||||||
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
|
|
||||||
|
|||||||
@@ -34,11 +34,7 @@ from transformers.onnx import (
|
|||||||
validate_model_outputs,
|
validate_model_outputs,
|
||||||
)
|
)
|
||||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||||
from transformers.onnx.utils import (
|
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||||
compute_effective_axis_dimension,
|
|
||||||
compute_serialized_parameters_size,
|
|
||||||
flatten_output_collection_property,
|
|
||||||
)
|
|
||||||
from transformers.testing_utils import require_onnx, require_torch, slow
|
from transformers.testing_utils import require_onnx, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -95,7 +91,7 @@ class OnnxUtilsTestCaseV2(TestCase):
|
|||||||
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
|
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
|
||||||
"""
|
"""
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
flatten_output_collection_property("past_key", [[0], [1], [2]]),
|
OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
|
||||||
{
|
{
|
||||||
"past_key.0": 0,
|
"past_key.0": 0,
|
||||||
"past_key.1": 1,
|
"past_key.1": 1,
|
||||||
|
|||||||
Reference in New Issue
Block a user