Fix onnx export input names order (#4641)
* pass on tokenizer to pipeline * order input names when convert to onnx * update style * remove unused imports * make ordered inputs list needs to be mutable * add test custom bert model * remove unused imports
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from argparse import ArgumentParser
|
||||
from itertools import takewhile
|
||||
from os import listdir, makedirs
|
||||
from os.path import abspath, dirname, exists
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@@ -38,14 +37,17 @@ def ensure_valid_input(model, tokens, input_names):
|
||||
|
||||
"""
|
||||
model_args_name = model.forward.__code__.co_varnames
|
||||
model_args_pos = [(model_args_name.index(name) - 1, name) for name in input_names]
|
||||
model_args = [None] * (max(map(lambda x: x[0], model_args_pos)) + 1)
|
||||
|
||||
for arg_pos, arg_name in model_args_pos:
|
||||
model_args[arg_pos] = tokens[arg_name]
|
||||
ordered_input_names = []
|
||||
model_args = []
|
||||
for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
|
||||
if arg_name in input_names:
|
||||
ordered_input_names.append(arg_name)
|
||||
model_args.append(tokens[arg_name])
|
||||
else:
|
||||
break
|
||||
|
||||
model_args = tuple(model_args) # Need to be ordered
|
||||
return tuple(takewhile(lambda arg: arg is not None, model_args))
|
||||
return ordered_input_names, tuple(model_args)
|
||||
|
||||
|
||||
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
|
||||
@@ -117,13 +119,13 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
|
||||
|
||||
with torch.no_grad():
|
||||
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
|
||||
model_args = ensure_valid_input(nlp.model, tokens, input_names)
|
||||
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
|
||||
|
||||
export(
|
||||
nlp.model,
|
||||
model_args,
|
||||
f=output,
|
||||
input_names=input_names,
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
|
||||
Reference in New Issue
Block a user