Autodocument the list of ONNX-supported models (#13884)
This commit is contained in:
@@ -37,12 +37,18 @@ architectures, and are made to be easily extendable to other architectures.
|
|||||||
|
|
||||||
Ready-made configurations include the following models:
|
Ready-made configurations include the following models:
|
||||||
|
|
||||||
|
..
|
||||||
|
This table is automatically generated by make style, do not fill manually!
|
||||||
|
|
||||||
- ALBERT
|
- ALBERT
|
||||||
- BART
|
- BART
|
||||||
- BERT
|
- BERT
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
- GPT-2
|
- GPT Neo
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
|
- Longformer
|
||||||
|
- mBART
|
||||||
|
- OpenAI GPT-2
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
- T5
|
- T5
|
||||||
- XLM-RoBERTa
|
- XLM-RoBERTa
|
||||||
|
|||||||
@@ -62,6 +62,15 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
|
|||||||
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||||
|
|
||||||
|
|
||||||
|
# This is to make sure the transformers module imported is the one in the repo: the package is loaded directly from
# `TRANSFORMERS_PATH/__init__.py`, with `TRANSFORMERS_PATH` registered as the location to search for its submodules.
# NOTE(review): `Loader.load_module()` is deprecated in the Python docs in favour of `exec_module()` - confirm whether
# this should be migrated.
spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(TRANSFORMERS_PATH, "__init__.py"),
    submodule_search_locations=[TRANSFORMERS_PATH],
)
transformers_module = spec.loader.load_module()
|
||||||
|
|
||||||
|
|
||||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||||
def camel_case_split(identifier):
|
def camel_case_split(identifier):
|
||||||
"Split a camelcased `identifier` into words."
|
"Split a camelcased `identifier` into words."
|
||||||
@@ -78,19 +87,11 @@ def _center_text(text, width):
|
|||||||
|
|
||||||
def get_model_table_from_auto_modules():
|
def get_model_table_from_auto_modules():
|
||||||
"""Generates an up-to-date model table from the content of the auto modules."""
|
"""Generates an up-to-date model table from the content of the auto modules."""
|
||||||
# This is to make sure the transformers module imported is the one in the repo.
|
|
||||||
spec = importlib.util.spec_from_file_location(
|
|
||||||
"transformers",
|
|
||||||
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
|
|
||||||
submodule_search_locations=[TRANSFORMERS_PATH],
|
|
||||||
)
|
|
||||||
transformers = spec.loader.load_module()
|
|
||||||
|
|
||||||
# Dictionary model names to config.
|
# Dictionary model names to config.
|
||||||
config_maping_names = transformers.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||||
model_name_to_config = {
|
model_name_to_config = {
|
||||||
name: config_maping_names[code]
|
name: config_maping_names[code]
|
||||||
for code, name in transformers.MODEL_NAMES_MAPPING.items()
|
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
|
||||||
if code in config_maping_names
|
if code in config_maping_names
|
||||||
}
|
}
|
||||||
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
||||||
@@ -103,7 +104,7 @@ def get_model_table_from_auto_modules():
|
|||||||
flax_models = collections.defaultdict(bool)
|
flax_models = collections.defaultdict(bool)
|
||||||
|
|
||||||
# Let's lookup through all transformers object (once).
|
# Let's lookup through all transformers object (once).
|
||||||
for attr_name in dir(transformers):
|
for attr_name in dir(transformers_module):
|
||||||
lookup_dict = None
|
lookup_dict = None
|
||||||
if attr_name.endswith("Tokenizer"):
|
if attr_name.endswith("Tokenizer"):
|
||||||
lookup_dict = slow_tokenizers
|
lookup_dict = slow_tokenizers
|
||||||
@@ -178,9 +179,56 @@ def check_model_table(overwrite=False):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def has_onnx(model_type):
    """
    Tell whether `model_type` is supported by ONNX.

    Support is detected by looking, in the module that defines the configuration class registered for `model_type`,
    for an attribute named like that class with `Config` replaced by `OnnxConfig`.

    Args:
        model_type: key to look up in the auto `CONFIG_MAPPING`.

    Returns:
        `False` when `model_type` is not in `CONFIG_MAPPING` or when no matching ONNX config attribute exists,
        `True` otherwise.
    """
    known_configs = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
    if model_type not in known_configs:
        return False

    config_class = known_configs[model_type]

    # Walk from the repo's `transformers` module down to the module defining the config class. The first component
    # of the dotted module path is dropped (presumably the top-level package name, which `transformers_module`
    # already stands for).
    current = transformers_module
    for attribute in config_class.__module__.split(".")[1:]:
        current = getattr(current, attribute)

    return hasattr(current, config_class.__name__.replace("Config", "OnnxConfig"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_onnx_model_list():
    """
    Return the list of models supporting ONNX.

    Returns:
        A string with one `- <model name>` bullet per line (newline-terminated), sorted case-insensitively by model
        name, containing every model type of `MODEL_NAMES_MAPPING` for which `has_onnx` returns `True`.
    """
    # The candidate model types are the keys of `MODEL_NAMES_MAPPING`; `has_onnx` itself rejects any type that has
    # no entry in `CONFIG_MAPPING`, so no separate lookup of the config mapping is needed here.
    model_names = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
    onnx_model_names = [name for model_type, name in model_names.items() if has_onnx(model_type)]
    onnx_model_names.sort(key=lambda x: x.upper())
    return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def check_onnx_model_list(overwrite=False):
    """
    Check that the ONNX model list in `serialization.rst` matches the list generated from the library.

    Args:
        overwrite: when the two lists differ, rewrite the file with the generated list instead of raising.

    Raises:
        ValueError: if the documented list is out of date and `overwrite` is false.
    """
    doc_file = os.path.join(PATH_TO_DOCS, "serialization.rst")
    # NOTE(review): the prompts are matched against the file by `_find_text_in_file`; confirm their leading/trailing
    # whitespace matches the markers as they appear in serialization.rst.
    current_list, start_index, end_index, lines = _find_text_in_file(
        filename=doc_file,
        start_prompt=" This table is automatically generated by make style, do not fill manually!",
        end_prompt="This conversion is handled with the PyTorch version of models ",
    )
    expected_list = get_onnx_model_list()

    if current_list == expected_list:
        return
    if not overwrite:
        raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.")

    # Splice the regenerated list between the untouched text before and after the documented one.
    with open(doc_file, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(lines[:start_index] + [expected_list] + lines[end_index:])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point: run both documentation checks, fixing the files in place when asked to.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    options = cli.parse_args()

    for run_check in (check_model_table, check_onnx_model_list):
        run_check(options.fix_and_overwrite)
|
||||||
|
|||||||
Reference in New Issue
Block a user