Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
101
utils/check_tf_ops.py
Normal file
101
utils/check_tf_ops.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_copies.py
|
||||
REPO_PATH = "."
|
||||
|
||||
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
|
||||
INTERNAL_OPS = [
|
||||
"Assert",
|
||||
"AssignVariableOp",
|
||||
"EmptyTensorList",
|
||||
"MergeV2Checkpoints",
|
||||
"ReadVariableOp",
|
||||
"ResourceGather",
|
||||
"RestoreV2",
|
||||
"SaveV2",
|
||||
"ShardedFilename",
|
||||
"StatefulPartitionedCall",
|
||||
"StaticRegexFullMatch",
|
||||
"VarHandleOp",
|
||||
]
|
||||
|
||||
|
||||
def onnx_compliancy(saved_model_path, strict, opset):
|
||||
saved_model = SavedModel()
|
||||
onnx_ops = []
|
||||
|
||||
with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
|
||||
onnx_opsets = json.load(f)["opsets"]
|
||||
|
||||
for i in range(1, opset + 1):
|
||||
onnx_ops.extend(onnx_opsets[str(i)])
|
||||
|
||||
with open(saved_model_path, "rb") as f:
|
||||
saved_model.ParseFromString(f.read())
|
||||
|
||||
model_op_names = set()
|
||||
|
||||
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
|
||||
for meta_graph in saved_model.meta_graphs:
|
||||
# Add operations in the graph definition
|
||||
model_op_names.update(node.op for node in meta_graph.graph_def.node)
|
||||
|
||||
# Go through the functions in the graph definition
|
||||
for func in meta_graph.graph_def.library.function:
|
||||
# Add operations in each function
|
||||
model_op_names.update(node.op for node in func.node_def)
|
||||
|
||||
# Convert to list, sorted if you want
|
||||
model_op_names = sorted(model_op_names)
|
||||
incompatible_ops = []
|
||||
|
||||
for op in model_op_names:
|
||||
if op not in onnx_ops and op not in INTERNAL_OPS:
|
||||
incompatible_ops.append(op)
|
||||
|
||||
if strict and len(incompatible_ops) > 0:
|
||||
raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
|
||||
elif len(incompatible_ops) > 0:
|
||||
print(f"Found the following incompatible ops for the opset {opset}:")
|
||||
print(*incompatible_ops, sep="\n")
|
||||
else:
|
||||
print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
|
||||
parser.add_argument(
|
||||
"--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.framework == "onnx":
|
||||
onnx_compliancy(args.saved_model_path, args.strict, args.opset)
|
||||
245
utils/tf_ops/onnx.json
Normal file
245
utils/tf_ops/onnx.json
Normal file
@@ -0,0 +1,245 @@
|
||||
{
|
||||
"opsets": {
|
||||
"1": [
|
||||
"Abs",
|
||||
"Add",
|
||||
"AddV2",
|
||||
"ArgMax",
|
||||
"ArgMin",
|
||||
"AvgPool",
|
||||
"AvgPool3D",
|
||||
"BatchMatMul",
|
||||
"BatchMatMulV2",
|
||||
"BatchToSpaceND",
|
||||
"BiasAdd",
|
||||
"BiasAddV1",
|
||||
"Cast",
|
||||
"Ceil",
|
||||
"CheckNumerics",
|
||||
"ComplexAbs",
|
||||
"Concat",
|
||||
"ConcatV2",
|
||||
"Const",
|
||||
"ConstV2",
|
||||
"Conv1D",
|
||||
"Conv2D",
|
||||
"Conv2DBackpropInput",
|
||||
"Conv3D",
|
||||
"Conv3DBackpropInputV2",
|
||||
"DepthToSpace",
|
||||
"DepthwiseConv2d",
|
||||
"DepthwiseConv2dNative",
|
||||
"Div",
|
||||
"Dropout",
|
||||
"Elu",
|
||||
"Equal",
|
||||
"Erf",
|
||||
"Exp",
|
||||
"ExpandDims",
|
||||
"Flatten",
|
||||
"Floor",
|
||||
"Gather",
|
||||
"GatherNd",
|
||||
"GatherV2",
|
||||
"Greater",
|
||||
"Identity",
|
||||
"IdentityN",
|
||||
"If",
|
||||
"LRN",
|
||||
"LSTMBlockCell",
|
||||
"LeakyRelu",
|
||||
"Less",
|
||||
"Log",
|
||||
"LogSoftmax",
|
||||
"LogicalAnd",
|
||||
"LogicalNot",
|
||||
"LogicalOr",
|
||||
"LookupTableSizeV2",
|
||||
"MatMul",
|
||||
"Max",
|
||||
"MaxPool",
|
||||
"MaxPool3D",
|
||||
"MaxPoolV2",
|
||||
"Maximum",
|
||||
"Mean",
|
||||
"Min",
|
||||
"Minimum",
|
||||
"MirrorPad",
|
||||
"Mul",
|
||||
"Neg",
|
||||
"NoOp",
|
||||
"NotEqual",
|
||||
"OneHot",
|
||||
"Pack",
|
||||
"Pad",
|
||||
"PadV2",
|
||||
"Placeholder",
|
||||
"PlaceholderV2",
|
||||
"PlaceholderWithDefault",
|
||||
"Pow",
|
||||
"Prod",
|
||||
"RFFT",
|
||||
"RandomNormal",
|
||||
"RandomNormalLike",
|
||||
"RandomUniform",
|
||||
"RandomUniformLike",
|
||||
"RealDiv",
|
||||
"Reciprocal",
|
||||
"Relu",
|
||||
"Relu6",
|
||||
"Reshape",
|
||||
"Rsqrt",
|
||||
"Selu",
|
||||
"Shape",
|
||||
"Sigmoid",
|
||||
"Sign",
|
||||
"Size",
|
||||
"Slice",
|
||||
"Softmax",
|
||||
"Softplus",
|
||||
"Softsign",
|
||||
"SpaceToBatchND",
|
||||
"SpaceToDepth",
|
||||
"Split",
|
||||
"SplitV",
|
||||
"Sqrt",
|
||||
"Square",
|
||||
"SquaredDifference",
|
||||
"Squeeze",
|
||||
"StatelessIf",
|
||||
"StopGradient",
|
||||
"StridedSlice",
|
||||
"StringJoin",
|
||||
"Sub",
|
||||
"Sum",
|
||||
"Tanh",
|
||||
"Tile",
|
||||
"TopKV2",
|
||||
"Transpose",
|
||||
"TruncateDiv",
|
||||
"Unpack",
|
||||
"ZerosLike"
|
||||
],
|
||||
"2": [],
|
||||
"3": [],
|
||||
"4": [],
|
||||
"5": [],
|
||||
"6": [
|
||||
"AddN",
|
||||
"All",
|
||||
"Any",
|
||||
"FloorDiv",
|
||||
"FusedBatchNorm",
|
||||
"FusedBatchNormV2",
|
||||
"FusedBatchNormV3"
|
||||
],
|
||||
"7": [
|
||||
"Acos",
|
||||
"Asin",
|
||||
"Atan",
|
||||
"Cos",
|
||||
"Fill",
|
||||
"FloorMod",
|
||||
"GreaterEqual",
|
||||
"LessEqual",
|
||||
"Loop",
|
||||
"MatrixBandPart",
|
||||
"Multinomial",
|
||||
"Range",
|
||||
"ResizeBilinear",
|
||||
"ResizeNearestNeighbor",
|
||||
"Scan",
|
||||
"Select",
|
||||
"SelectV2",
|
||||
"Sin",
|
||||
"SoftmaxCrossEntropyWithLogits",
|
||||
"SparseSoftmaxCrossEntropyWithLogits",
|
||||
"StatelessWhile",
|
||||
"Tan",
|
||||
"TensorListFromTensor",
|
||||
"TensorListGetItem",
|
||||
"TensorListLength",
|
||||
"TensorListReserve",
|
||||
"TensorListResize",
|
||||
"TensorListSetItem",
|
||||
"TensorListStack",
|
||||
"While"
|
||||
],
|
||||
"8": [
|
||||
"BroadcastTo",
|
||||
"ClipByValue",
|
||||
"FIFOQueueV2",
|
||||
"HashTableV2",
|
||||
"IteratorGetNext",
|
||||
"IteratorV2",
|
||||
"LookupTableFindV2",
|
||||
"MaxPoolWithArgmax",
|
||||
"QueueDequeueManyV2",
|
||||
"QueueDequeueUpToV2",
|
||||
"QueueDequeueV2",
|
||||
"ReverseSequence"
|
||||
],
|
||||
"9": [
|
||||
"SegmentMax",
|
||||
"SegmentMean",
|
||||
"SegmentMin",
|
||||
"SegmentProd",
|
||||
"SegmentSum",
|
||||
"Sinh",
|
||||
"SparseSegmentMean",
|
||||
"SparseSegmentMeanWithNumSegments",
|
||||
"SparseSegmentSqrtN",
|
||||
"SparseSegmentSqrtNWithNumSegments",
|
||||
"SparseSegmentSum",
|
||||
"SparseSegmentSumWithNumSegments",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
"UnsortedSegmentSum",
|
||||
"Where"
|
||||
],
|
||||
"10": [
|
||||
"CropAndResize",
|
||||
"CudnnRNN",
|
||||
"DynamicStitch",
|
||||
"FakeQuantWithMinMaxArgs",
|
||||
"IsFinite",
|
||||
"IsInf",
|
||||
"NonMaxSuppressionV2",
|
||||
"NonMaxSuppressionV3",
|
||||
"NonMaxSuppressionV4",
|
||||
"NonMaxSuppressionV5",
|
||||
"ParallelDynamicStitch",
|
||||
"ReverseV2",
|
||||
"Roll"
|
||||
],
|
||||
"11": [
|
||||
"Bincount",
|
||||
"Cumsum",
|
||||
"InvertPermutation",
|
||||
"LeftShift",
|
||||
"MatrixDeterminant",
|
||||
"MatrixDiagPart",
|
||||
"MatrixDiagPartV2",
|
||||
"MatrixDiagPartV3",
|
||||
"RaggedRange",
|
||||
"RightShift",
|
||||
"Round",
|
||||
"ScatterNd",
|
||||
"SparseFillEmptyRows",
|
||||
"SparseReshape",
|
||||
"SparseToDense",
|
||||
"TensorScatterUpdate",
|
||||
"Unique"
|
||||
],
|
||||
"12": [
|
||||
"Einsum",
|
||||
"MatrixDiag",
|
||||
"MatrixDiagV2",
|
||||
"MatrixDiagV3",
|
||||
"MatrixSetDiagV3",
|
||||
"SquaredDistance"
|
||||
],
|
||||
"13": []
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user