T5 with past ONNX export (#13014)

T5 with past ONNX export, and more explicit past_key_values inputs and outputs names for ONNX model

Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
Michael Benayoun
2021-08-06 15:46:26 +02:00
committed by GitHub
parent ee11224611
commit dc420b0eb1
8 changed files with 131 additions and 62 deletions

View File

@@ -34,11 +34,7 @@ from transformers.onnx import (
validate_model_outputs,
)
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import (
compute_effective_axis_dimension,
compute_serialized_parameters_size,
flatten_output_collection_property,
)
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_torch, slow
@@ -95,7 +91,7 @@ class OnnxUtilsTestCaseV2(TestCase):
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
"""
self.assertEqual(
flatten_output_collection_property("past_key", [[0], [1], [2]]),
OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
{
"past_key.0": 0,
"past_key.1": 1,