From 7d83655da906b2e0021d9ecbf4df0298c6ae895a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 5 Oct 2021 22:43:16 -0400 Subject: [PATCH] Autodocument the list of ONNX-supported models (#13884) --- docs/source/serialization.rst | 8 +++- utils/check_table.py | 70 +++++++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 285e24cb56..53e075e13c 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -37,12 +37,18 @@ architectures, and are made to be easily extendable to other architectures. Ready-made configurations include the following models: +.. + This table is automatically generated by make style, do not fill manually! + - ALBERT - BART - BERT - DistilBERT -- GPT-2 +- GPT Neo - LayoutLM +- Longformer +- mBART +- OpenAI GPT-2 - RoBERTa - T5 - XLM-RoBERTa diff --git a/utils/check_table.py b/utils/check_table.py index 6376360969..042d0a9cb6 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -62,6 +62,15 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") +# This is to make sure the transformers module imported is the one in the repo. +spec = importlib.util.spec_from_file_location( + "transformers", + os.path.join(TRANSFORMERS_PATH, "__init__.py"), + submodule_search_locations=[TRANSFORMERS_PATH], +) +transformers_module = spec.loader.load_module() + + # Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python def camel_case_split(identifier): "Split a camelcased `identifier` into words." @@ -78,19 +87,11 @@ def _center_text(text, width): def get_model_table_from_auto_modules(): """Generates an up-to-date model table from the content of the auto modules.""" - # This is to make sure the transformers module imported is the one in the repo. - spec = importlib.util.spec_from_file_location( - "transformers", - os.path.join(TRANSFORMERS_PATH, "__init__.py"), - submodule_search_locations=[TRANSFORMERS_PATH], - ) - transformers = spec.loader.load_module() - # Dictionary model names to config. - config_maping_names = transformers.models.auto.configuration_auto.CONFIG_MAPPING_NAMES + config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES model_name_to_config = { name: config_maping_names[code] - for code, name in transformers.MODEL_NAMES_MAPPING.items() + for code, name in transformers_module.MODEL_NAMES_MAPPING.items() if code in config_maping_names } model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()} @@ -103,7 +104,7 @@ def get_model_table_from_auto_modules(): flax_models = collections.defaultdict(bool) # Let's lookup through all transformers object (once). - for attr_name in dir(transformers): + for attr_name in dir(transformers_module): lookup_dict = None if attr_name.endswith("Tokenizer"): lookup_dict = slow_tokenizers @@ -178,9 +179,56 @@ def check_model_table(overwrite=False): ) +def has_onnx(model_type): + """ + Returns whether `model_type` is supported by ONNX (by checking if there is an ONNX config) or not. + """ + config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING + if model_type not in config_mapping: + return False + config = config_mapping[model_type] + config_module = config.__module__ + module = transformers_module + for part in config_module.split(".")[1:]: + module = getattr(module, part) + config_name = config.__name__ + onnx_config_name = config_name.replace("Config", "OnnxConfig") + return hasattr(module, onnx_config_name) + + +def get_onnx_model_list(): + """ + Return the list of models supporting ONNX. + """ + config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING + model_names = config_mapping = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING + onnx_model_types = [model_type for model_type in config_mapping.keys() if has_onnx(model_type)] + onnx_model_names = [model_names[model_type] for model_type in onnx_model_types] + onnx_model_names.sort(key=lambda x: x.upper()) + return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n" + + +def check_onnx_model_list(overwrite=False): + """Check the model list in the serialization.rst is consistent with the state of the lib and maybe `overwrite`.""" + current_list, start_index, end_index, lines = _find_text_in_file( + filename=os.path.join(PATH_TO_DOCS, "serialization.rst"), + start_prompt=" This table is automatically generated by make style, do not fill manually!", + end_prompt="This conversion is handled with the PyTorch version of models ", + ) + new_list = get_onnx_model_list() + + if current_list != new_list: + if overwrite: + with open(os.path.join(PATH_TO_DOCS, "serialization.rst"), "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines[:start_index] + [new_list] + lines[end_index:]) + else: + raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.") + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") args = parser.parse_args() check_model_table(args.fix_and_overwrite) + check_onnx_model_list(args.fix_and_overwrite)