Cleanup more auto mapping names (#21909)

* fix auto 2

* fix auto 2

* fix task guide issue

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-03-03 14:43:44 +01:00
committed by GitHub
parent b05e0bec88
commit 02a77fa04c
3 changed files with 30 additions and 18 deletions

View File

@@ -72,14 +72,24 @@ TASK_GUIDE_TO_MODELS = {
"document_question_answering.mdx": transformers_module.models.auto.modeling_auto.MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
}
# This list contains model types used in some task guides that are not in `CONFIG_MAPPING_NAMES` (therefore not in any
# `MODEL_MAPPING_NAMES` or any `MODEL_FOR_XXX_MAPPING_NAMES`).
SPECIAL_TASK_GUIDE_TO_MODEL_TYPES = {
"summarization.mdx": ("nllb",),
"translation.mdx": ("nllb",),
}
def get_model_list_for_task(task_guide):
"""
Return the list of models supporting given task.
"""
config_maping_names = TASK_GUIDE_TO_MODELS[task_guide]
model_maping_names = TASK_GUIDE_TO_MODELS[task_guide]
special_model_types = SPECIAL_TASK_GUIDE_TO_MODEL_TYPES.get(task_guide, set())
model_names = {
code: name for code, name in transformers_module.MODEL_NAMES_MAPPING.items() if code in config_maping_names
code: name
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
if (code in model_maping_names or code in special_model_types)
}
return ", ".join([f"[{name}](../model_doc/{code})" for code, name in model_names.items()]) + "\n"