Fix onnx export input names order (#4641)
* pass on tokenizer to pipeline * order input names when convert to onnx * update style * remove unused imports * make ordered inputs list needs to be mutable * add test custom bert model * remove unused imports
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from itertools import takewhile
|
|
||||||
from os import listdir, makedirs
|
from os import listdir, makedirs
|
||||||
from os.path import abspath, dirname, exists
|
from os.path import abspath, dirname, exists
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@@ -38,14 +37,17 @@ def ensure_valid_input(model, tokens, input_names):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
model_args_name = model.forward.__code__.co_varnames
|
model_args_name = model.forward.__code__.co_varnames
|
||||||
model_args_pos = [(model_args_name.index(name) - 1, name) for name in input_names]
|
|
||||||
model_args = [None] * (max(map(lambda x: x[0], model_args_pos)) + 1)
|
|
||||||
|
|
||||||
for arg_pos, arg_name in model_args_pos:
|
ordered_input_names = []
|
||||||
model_args[arg_pos] = tokens[arg_name]
|
model_args = []
|
||||||
|
for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
|
||||||
|
if arg_name in input_names:
|
||||||
|
ordered_input_names.append(arg_name)
|
||||||
|
model_args.append(tokens[arg_name])
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
model_args = tuple(model_args) # Need to be ordered
|
return ordered_input_names, tuple(model_args)
|
||||||
return tuple(takewhile(lambda arg: arg is not None, model_args))
|
|
||||||
|
|
||||||
|
|
||||||
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
|
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
|
||||||
@@ -117,13 +119,13 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
|
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
|
||||||
model_args = ensure_valid_input(nlp.model, tokens, input_names)
|
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
|
||||||
|
|
||||||
export(
|
export(
|
||||||
nlp.model,
|
nlp.model,
|
||||||
model_args,
|
model_args,
|
||||||
f=output,
|
f=output,
|
||||||
input_names=input_names,
|
input_names=ordered_input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
do_constant_folding=True,
|
do_constant_folding=True,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from os import sep
|
|
||||||
from os.path import dirname, exists
|
from os.path import dirname, exists
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
|
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||||
|
|
||||||
from tests.utils import require_tf, require_torch, slow
|
from tests.utils import require_tf, require_torch, slow
|
||||||
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
|
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
|
||||||
@@ -33,17 +33,34 @@ class OnnxExportTestCase(unittest.TestCase):
|
|||||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||||
self._test_export(model, "pt", 11)
|
self._test_export(model, "pt", 11)
|
||||||
|
|
||||||
def _test_export(self, model, framework, opset):
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_export_custom_bert_model(self):
|
||||||
|
from transformers import BertModel
|
||||||
|
|
||||||
|
vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
|
||||||
|
with NamedTemporaryFile(mode="w+t") as vocab_file:
|
||||||
|
vocab_file.write("\n".join(vocab))
|
||||||
|
vocab_file.flush()
|
||||||
|
tokenizer = BertTokenizerFast(vocab_file.name)
|
||||||
|
|
||||||
|
with TemporaryDirectory() as bert_save_dir:
|
||||||
|
model = BertModel(BertConfig(vocab_size=len(vocab)))
|
||||||
|
model.save_pretrained(bert_save_dir)
|
||||||
|
self._test_export(bert_save_dir, "pt", 11, tokenizer)
|
||||||
|
|
||||||
|
def _test_export(self, model, framework, opset, tokenizer=None):
|
||||||
try:
|
try:
|
||||||
# Compute path
|
# Compute path
|
||||||
path = "onnx" + sep + model + ".onnx"
|
with TemporaryDirectory() as tempdir:
|
||||||
|
path = tempdir + "/model.onnx"
|
||||||
|
|
||||||
# Remove folder if exists
|
# Remove folder if exists
|
||||||
if exists(dirname(path)):
|
if exists(dirname(path)):
|
||||||
rmtree(dirname(path))
|
rmtree(dirname(path))
|
||||||
|
|
||||||
# Export
|
# Export
|
||||||
convert(framework, model, path, opset)
|
convert(framework, model, path, opset, tokenizer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.fail(e)
|
self.fail(e)
|
||||||
|
|
||||||
@@ -99,20 +116,25 @@ class OnnxExportTestCase(unittest.TestCase):
|
|||||||
# All generated args are valid
|
# All generated args are valid
|
||||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||||
tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
|
tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
|
||||||
inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
|
ordered_input_names, inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
|
||||||
|
|
||||||
# Should have exactly the same number of args (all are valid)
|
# Should have exactly the same number of args (all are valid)
|
||||||
self.assertEqual(len(inputs_args), 3)
|
self.assertEqual(len(inputs_args), 3)
|
||||||
|
|
||||||
|
# Should have exactly the same input names
|
||||||
|
self.assertEqual(set(ordered_input_names), set(input_names))
|
||||||
|
|
||||||
# Parameter should be reordered according to their respective place in the function:
|
# Parameter should be reordered according to their respective place in the function:
|
||||||
# (input_ids, token_type_ids, attention_mask)
|
# (input_ids, token_type_ids, attention_mask)
|
||||||
self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))
|
self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))
|
||||||
|
|
||||||
# Generated args are interleaved with another args (for instance parameter "past" in GPT2)
|
# Generated args are interleaved with another args (for instance parameter "past" in GPT2)
|
||||||
inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
|
ordered_input_names, inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
|
||||||
|
|
||||||
# Should have exactly the one arg (all before the one not provided "some_other_args")
|
# Should have exactly the one arg (all before the one not provided "some_other_args")
|
||||||
self.assertEqual(len(inputs_args), 1)
|
self.assertEqual(len(inputs_args), 1)
|
||||||
|
self.assertEqual(len(ordered_input_names), 1)
|
||||||
|
|
||||||
# Should have only "input_ids"
|
# Should have only "input_ids"
|
||||||
self.assertEqual(inputs_args[0], tokens["input_ids"])
|
self.assertEqual(inputs_args[0], tokens["input_ids"])
|
||||||
|
self.assertEqual(ordered_input_names[0], "input_ids")
|
||||||
|
|||||||
Reference in New Issue
Block a user