T5 with past ONNX export (#13014)
T5 with past ONNX export, and more explicit past_key_values inputs and outputs names for ONNX model Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
""" GPT Neo model configuration """
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Dict, Iterable, Mapping, Optional
|
||||
|
||||
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
@@ -253,8 +253,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
|
||||
if self.use_past:
|
||||
for i in range(self._number_key_values):
|
||||
common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
|
||||
for i in range(self._config.num_layers):
|
||||
if self._config.attention_layers[i] == "local":
|
||||
common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
|
||||
else:
|
||||
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
|
||||
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
|
||||
|
||||
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
|
||||
|
||||
@@ -264,9 +268,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = super().outputs
|
||||
if self.use_past:
|
||||
for i in range(self._number_key_values):
|
||||
common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
|
||||
|
||||
for i in range(self._config.num_layers):
|
||||
if self._config.attention_layers[i] == "local":
|
||||
common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
|
||||
else:
|
||||
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
|
||||
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
|
||||
return common_outputs
|
||||
|
||||
def generate_dummy_inputs(
|
||||
@@ -315,3 +322,18 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
if len(t) == 1:
|
||||
flatten_output[f"{name}.{idx}.key_value"] = t[0]
|
||||
else:
|
||||
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
|
||||
|
||||
@@ -14,10 +14,11 @@
|
||||
# limitations under the License.
|
||||
""" T5 model configuration """
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Dict, Iterable, Mapping, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer, TensorType
|
||||
|
||||
from ... import is_torch_available
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfigWithPast
|
||||
from ...utils import logging
|
||||
@@ -140,9 +141,6 @@ class T5Config(PretrainedConfig):
|
||||
|
||||
|
||||
class T5OnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
||||
super().__init__(config, use_past)
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_inputs = OrderedDict(
|
||||
@@ -155,29 +153,30 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
||||
)
|
||||
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
|
||||
common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
|
||||
common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
|
||||
common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
|
||||
for i in range(0, self._config.num_layers):
|
||||
common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
|
||||
common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
|
||||
|
||||
return common_inputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
common_outputs = OrderedDict(
|
||||
[
|
||||
("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
|
||||
("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
|
||||
]
|
||||
)
|
||||
common_outputs = super().outputs
|
||||
|
||||
if "last_hidden_state" in common_outputs:
|
||||
common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
|
||||
|
||||
if self.use_past:
|
||||
for i in range(self._config.num_layers):
|
||||
common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
|
||||
common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
|
||||
common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
|
||||
common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
|
||||
common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
|
||||
common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
|
||||
common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
|
||||
common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
|
||||
|
||||
if self.task == "default":
|
||||
common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
|
||||
|
||||
return common_outputs
|
||||
|
||||
@@ -189,8 +188,6 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
||||
is_pair: bool = False,
|
||||
framework: Optional[TensorType] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
if self.use_past:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Generate encoder inputs
|
||||
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||
@@ -199,4 +196,45 @@ class T5OnnxConfig(OnnxConfigWithPast):
|
||||
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
|
||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||
|
||||
return dict(**encoder_inputs, **decoder_inputs)
|
||||
ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||
if self.use_past:
|
||||
if not is_torch_available():
|
||||
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
||||
else:
|
||||
import torch
|
||||
batch = encoder_inputs["input_ids"].shape[0]
|
||||
encoder_seq_length = encoder_inputs["input_ids"].shape[1]
|
||||
encoder_shape = (
|
||||
batch,
|
||||
self._config.num_heads,
|
||||
encoder_seq_length,
|
||||
self._config.hidden_size // self._config.num_heads,
|
||||
)
|
||||
decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
|
||||
|
||||
ordered_inputs["past_key_values"] = []
|
||||
for _ in range(self._config.num_layers):
|
||||
ordered_inputs["past_key_values"].append(
|
||||
(
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(decoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
torch.zeros(encoder_shape),
|
||||
)
|
||||
)
|
||||
|
||||
return ordered_inputs
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
|
||||
flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
|
||||
flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
|
||||
|
||||
@@ -429,8 +429,6 @@ class T5Attention(nn.Module):
|
||||
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
|
||||
int_seq_length = int(seq_length)
|
||||
|
||||
real_seq_length = seq_length
|
||||
|
||||
if past_key_value is not None:
|
||||
@@ -499,7 +497,7 @@ class T5Attention(nn.Module):
|
||||
# if key and values are already calculated
|
||||
# we want only the last query position bias
|
||||
if past_key_value is not None:
|
||||
position_bias = position_bias[:, :, -int_seq_length:, :]
|
||||
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
|
||||
|
||||
if mask is not None:
|
||||
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
||||
@@ -629,7 +627,7 @@ class T5Block(nn.Module):
|
||||
if len(past_key_value) != expected_num_past_key_values:
|
||||
raise ValueError(
|
||||
f"There should be {expected_num_past_key_values} past states. "
|
||||
f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
|
||||
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
|
||||
f"Got {len(past_key_value)} past key / value states"
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, List, Mapping, Optional
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
|
||||
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||
|
||||
@@ -59,6 +59,7 @@ class OnnxConfig(ABC):
|
||||
_TASKS_TO_COMMON_OUTPUTS = {
|
||||
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
||||
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
|
||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
||||
@@ -228,6 +229,24 @@ class OnnxConfig(ABC):
|
||||
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
|
||||
setattr(spec.o, spec.name, orig_op)
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Flatten any potential nested structure expanding the name of the field with the index of the element within the
|
||||
structure.
|
||||
|
||||
Args:
|
||||
name: The name of the nested structure
|
||||
field: The structure to, potentially, be flattened
|
||||
|
||||
Returns:
|
||||
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
|
||||
|
||||
"""
|
||||
from itertools import chain
|
||||
|
||||
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
|
||||
|
||||
|
||||
class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
def __init__(
|
||||
@@ -285,3 +304,15 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||
# Generate dummy inputs according to compute batch and sequence
|
||||
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
||||
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
|
||||
|
||||
@staticmethod
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
if name in ["present", "past_key_values"]:
|
||||
flatten_output = {}
|
||||
for idx, t in enumerate(field):
|
||||
flatten_output[f"{name}.{idx}.key"] = t[0]
|
||||
flatten_output[f"{name}.{idx}.value"] = t[1]
|
||||
|
||||
return flatten_output
|
||||
|
||||
return super().flatten_output_collection_property(name, field)
|
||||
|
||||
@@ -24,7 +24,6 @@ from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedMod
|
||||
from ..file_utils import is_torch_onnx_dict_inputs_support_available
|
||||
from ..utils import logging
|
||||
from .config import OnnxConfig
|
||||
from .utils import flatten_output_collection_property
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -163,7 +162,7 @@ def validate_model_outputs(
|
||||
if name == "past_key_values":
|
||||
name = "present"
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = flatten_output_collection_property(name, value)
|
||||
value = config.flatten_output_collection_property(name, value)
|
||||
ref_outputs_dict.update(value)
|
||||
else:
|
||||
ref_outputs_dict[name] = value
|
||||
@@ -172,7 +171,7 @@ def validate_model_outputs(
|
||||
onnx_inputs = {}
|
||||
for name, value in reference_model_inputs.items():
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = flatten_output_collection_property(name, value)
|
||||
value = config.flatten_output_collection_property(name, value)
|
||||
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
|
||||
else:
|
||||
onnx_inputs[name] = value.numpy()
|
||||
|
||||
@@ -21,6 +21,7 @@ if is_torch_available():
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
)
|
||||
@@ -46,6 +47,7 @@ class FeaturesManager:
|
||||
_TASKS_TO_AUTOMODELS = {
|
||||
"default": AutoModel,
|
||||
"causal-lm": AutoModelForCausalLM,
|
||||
"seq2seq-lm": AutoModelForSeq2SeqLM,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
"token-classification": AutoModelForTokenClassification,
|
||||
"multiple-choice": AutoModelForMultipleChoice,
|
||||
@@ -61,7 +63,9 @@ class FeaturesManager:
|
||||
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
||||
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
||||
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
|
||||
"t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
|
||||
"t5": supported_features_mapping(
|
||||
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
|
||||
),
|
||||
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
|
||||
"gpt-neo": supported_features_mapping(
|
||||
"default",
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
from ctypes import c_float, sizeof
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
|
||||
class ParameterFormat(Enum):
|
||||
@@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
|
||||
Size (in byte) taken to save all the parameters
|
||||
"""
|
||||
return num_parameters * dtype.size
|
||||
|
||||
|
||||
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Flatten any potential nested structure expanding the name of the field with the index of the element within the
|
||||
structure.
|
||||
|
||||
Args:
|
||||
name: The name of the nested structure
|
||||
field: The structure to, potentially, be flattened
|
||||
|
||||
Returns:
|
||||
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
|
||||
|
||||
"""
|
||||
from itertools import chain
|
||||
|
||||
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
|
||||
|
||||
@@ -34,11 +34,7 @@ from transformers.onnx import (
|
||||
validate_model_outputs,
|
||||
)
|
||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||
from transformers.onnx.utils import (
|
||||
compute_effective_axis_dimension,
|
||||
compute_serialized_parameters_size,
|
||||
flatten_output_collection_property,
|
||||
)
|
||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||
from transformers.testing_utils import require_onnx, require_torch, slow
|
||||
|
||||
|
||||
@@ -95,7 +91,7 @@ class OnnxUtilsTestCaseV2(TestCase):
|
||||
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
|
||||
"""
|
||||
self.assertEqual(
|
||||
flatten_output_collection_property("past_key", [[0], [1], [2]]),
|
||||
OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
|
||||
{
|
||||
"past_key.0": 0,
|
||||
"past_key.1": 1,
|
||||
|
||||
Reference in New Issue
Block a user