T5 with past ONNX export (#13014)

T5 with past ONNX export, and more explicit past_key_values inputs and outputs names for ONNX model

Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
Michael Benayoun
2021-08-06 15:46:26 +02:00
committed by GitHub
parent ee11224611
commit dc420b0eb1
8 changed files with 131 additions and 62 deletions

View File

@@ -15,7 +15,7 @@
""" GPT Neo model configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Any, Dict, Iterable, Mapping, Optional
from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
@@ -253,8 +253,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
for i in range(self._number_key_values):
common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i]
for i in range(self._config.num_layers):
if self._config.attention_layers[i] == "local":
common_inputs[f"past_key_values.{i}.key_value"] = {0: "batch", 1: "sequence"}
else:
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "sequence"}
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "sequence"}
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
@@ -264,9 +268,12 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super().outputs
if self.use_past:
for i in range(self._number_key_values):
common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i]
for i in range(self._config.num_layers):
if self._config.attention_layers[i] == "local":
common_outputs[f"present.{i}.key_value"] = {0: "batch", 1: "sequence"}
else:
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "sequence"}
return common_outputs
def generate_dummy_inputs(
@@ -315,3 +322,18 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
)
return ordered_inputs
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
if name in ["present", "past_key_values"]:
flatten_output = {}
for idx, t in enumerate(field):
if len(t) == 1:
flatten_output[f"{name}.{idx}.key_value"] = t[0]
else:
flatten_output[f"{name}.{idx}.key"] = t[0]
flatten_output[f"{name}.{idx}.value"] = t[1]
return flatten_output
return super().flatten_output_collection_property(name, field)

View File

@@ -14,10 +14,11 @@
# limitations under the License.
""" T5 model configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Any, Dict, Iterable, Mapping, Optional
from transformers import PreTrainedTokenizer, TensorType
from ... import is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast
from ...utils import logging
@@ -140,9 +141,6 @@ class T5Config(PretrainedConfig):
class T5OnnxConfig(OnnxConfigWithPast):
def __init__(self, config: PretrainedConfig, use_past: bool = False):
super().__init__(config, use_past)
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict(
@@ -155,29 +153,30 @@ class T5OnnxConfig(OnnxConfigWithPast):
)
if self.use_past:
for i in range(self._config.num_layers):
common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)
for i in range(0, self._config.num_layers):
common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"}
common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"}
common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"}
common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"}
return common_inputs
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = OrderedDict(
[
("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
]
)
common_outputs = super().outputs
if "last_hidden_state" in common_outputs:
common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"}
if self.use_past:
for i in range(self._config.num_layers):
common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)
common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"}
common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"}
common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"}
common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"}
if self.task == "default":
common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"}
return common_outputs
@@ -189,8 +188,6 @@ class T5OnnxConfig(OnnxConfigWithPast):
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
if self.use_past:
raise NotImplementedError()
# Generate encoder inputs
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
@@ -199,4 +196,45 @@ class T5OnnxConfig(OnnxConfigWithPast):
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
return dict(**encoder_inputs, **decoder_inputs)
ordered_inputs = dict(**encoder_inputs, **decoder_inputs)
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch = encoder_inputs["input_ids"].shape[0]
encoder_seq_length = encoder_inputs["input_ids"].shape[1]
encoder_shape = (
batch,
self._config.num_heads,
encoder_seq_length,
self._config.hidden_size // self._config.num_heads,
)
decoder_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads)
ordered_inputs["past_key_values"] = []
for _ in range(self._config.num_layers):
ordered_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)
return ordered_inputs
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
if name in ["present", "past_key_values"]:
flatten_output = {}
for idx, t in enumerate(field):
flatten_output[f"{name}.{idx}.decoder.key"] = t[0]
flatten_output[f"{name}.{idx}.decoder.value"] = t[1]
flatten_output[f"{name}.{idx}.encoder.key"] = t[2]
flatten_output[f"{name}.{idx}.encoder.value"] = t[3]
return flatten_output
return super().flatten_output_collection_property(name, field)

View File

@@ -429,8 +429,6 @@ class T5Attention(nn.Module):
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
int_seq_length = int(seq_length)
real_seq_length = seq_length
if past_key_value is not None:
@@ -499,7 +497,7 @@ class T5Attention(nn.Module):
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -int_seq_length:, :]
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
@@ -629,7 +627,7 @@ class T5Block(nn.Module):
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)

View File

@@ -14,7 +14,7 @@
import dataclasses
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, List, Mapping, Optional
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
@@ -59,6 +59,7 @@ class OnnxConfig(ABC):
_TASKS_TO_COMMON_OUTPUTS = {
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
@@ -228,6 +229,24 @@ class OnnxConfig(ABC):
orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
setattr(spec.o, spec.name, orig_op)
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
"""
Flatten any potential nested structure expanding the name of the field with the index of the element within the
structure.
Args:
name: The name of the nested structure
field: The structure to, potentially, be flattened
Returns:
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
"""
from itertools import chain
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
class OnnxConfigWithPast(OnnxConfig, ABC):
def __init__(
@@ -285,3 +304,15 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
@staticmethod
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
if name in ["present", "past_key_values"]:
flatten_output = {}
for idx, t in enumerate(field):
flatten_output[f"{name}.{idx}.key"] = t[0]
flatten_output[f"{name}.{idx}.value"] = t[1]
return flatten_output
return super().flatten_output_collection_property(name, field)

View File

@@ -24,7 +24,6 @@ from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedMod
from ..file_utils import is_torch_onnx_dict_inputs_support_available
from ..utils import logging
from .config import OnnxConfig
from .utils import flatten_output_collection_property
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -163,7 +162,7 @@ def validate_model_outputs(
if name == "past_key_values":
name = "present"
if isinstance(value, (list, tuple)):
value = flatten_output_collection_property(name, value)
value = config.flatten_output_collection_property(name, value)
ref_outputs_dict.update(value)
else:
ref_outputs_dict[name] = value
@@ -172,7 +171,7 @@ def validate_model_outputs(
onnx_inputs = {}
for name, value in reference_model_inputs.items():
if isinstance(value, (list, tuple)):
value = flatten_output_collection_property(name, value)
value = config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
else:
onnx_inputs[name] = value.numpy()

View File

@@ -21,6 +21,7 @@ if is_torch_available():
AutoModelForCausalLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
@@ -46,6 +47,7 @@ class FeaturesManager:
_TASKS_TO_AUTOMODELS = {
"default": AutoModel,
"causal-lm": AutoModelForCausalLM,
"seq2seq-lm": AutoModelForSeq2SeqLM,
"sequence-classification": AutoModelForSequenceClassification,
"token-classification": AutoModelForTokenClassification,
"multiple-choice": AutoModelForMultipleChoice,
@@ -61,7 +63,9 @@ class FeaturesManager:
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
"roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
"t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
"t5": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
),
"xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
"gpt-neo": supported_features_mapping(
"default",

View File

@@ -14,7 +14,6 @@
from ctypes import c_float, sizeof
from enum import Enum
from typing import Any, Dict, Iterable
class ParameterFormat(Enum):
@@ -62,21 +61,3 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
Size (in byte) taken to save all the parameters
"""
return num_parameters * dtype.size
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
"""
Flatten any potential nested structure expanding the name of the field with the index of the element within the
structure.
Args:
name: The name of the nested structure
field: The structure to, potentially, be flattened
Returns:
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
"""
from itertools import chain
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}

View File

@@ -34,11 +34,7 @@ from transformers.onnx import (
validate_model_outputs,
)
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import (
compute_effective_axis_dimension,
compute_serialized_parameters_size,
flatten_output_collection_property,
)
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_torch, slow
@@ -95,7 +91,7 @@ class OnnxUtilsTestCaseV2(TestCase):
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
"""
self.assertEqual(
flatten_output_collection_property("past_key", [[0], [1], [2]]),
OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
{
"past_key.0": 0,
"past_key.1": 1,