T5 with past ONNX export (#13014)
T5 with past ONNX export, and more explicit past_key_values inputs and outputs names for ONNX model Authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -34,11 +34,7 @@ from transformers.onnx import (
|
||||
validate_model_outputs,
|
||||
)
|
||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||
from transformers.onnx.utils import (
|
||||
compute_effective_axis_dimension,
|
||||
compute_serialized_parameters_size,
|
||||
flatten_output_collection_property,
|
||||
)
|
||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||
from transformers.testing_utils import require_onnx, require_torch, slow
|
||||
|
||||
|
||||
@@ -95,7 +91,7 @@ class OnnxUtilsTestCaseV2(TestCase):
|
||||
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
|
||||
"""
|
||||
self.assertEqual(
|
||||
flatten_output_collection_property("past_key", [[0], [1], [2]]),
|
||||
OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
|
||||
{
|
||||
"past_key.0": 0,
|
||||
"past_key.1": 1,
|
||||
|
||||
Reference in New Issue
Block a user