Clean up auto mapping names (#21903)
* add new test * fix after new test --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -43,7 +43,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("blenderbot", "BlenderbotModel"),
|
("blenderbot", "BlenderbotModel"),
|
||||||
("blenderbot-small", "BlenderbotSmallModel"),
|
("blenderbot-small", "BlenderbotSmallModel"),
|
||||||
("blip", "BlipModel"),
|
("blip", "BlipModel"),
|
||||||
("blip_2", "Blip2Model"),
|
("blip-2", "Blip2Model"),
|
||||||
("bloom", "BloomModel"),
|
("bloom", "BloomModel"),
|
||||||
("bridgetower", "BridgeTowerModel"),
|
("bridgetower", "BridgeTowerModel"),
|
||||||
("camembert", "CamembertModel"),
|
("camembert", "CamembertModel"),
|
||||||
@@ -64,7 +64,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("deberta", "DebertaModel"),
|
("deberta", "DebertaModel"),
|
||||||
("deberta-v2", "DebertaV2Model"),
|
("deberta-v2", "DebertaV2Model"),
|
||||||
("decision_transformer", "DecisionTransformerModel"),
|
("decision_transformer", "DecisionTransformerModel"),
|
||||||
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
|
|
||||||
("deformable_detr", "DeformableDetrModel"),
|
("deformable_detr", "DeformableDetrModel"),
|
||||||
("deit", "DeiTModel"),
|
("deit", "DeiTModel"),
|
||||||
("deta", "DetaModel"),
|
("deta", "DetaModel"),
|
||||||
@@ -128,7 +127,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("mvp", "MvpModel"),
|
("mvp", "MvpModel"),
|
||||||
("nat", "NatModel"),
|
("nat", "NatModel"),
|
||||||
("nezha", "NezhaModel"),
|
("nezha", "NezhaModel"),
|
||||||
("nllb", "M2M100Model"),
|
|
||||||
("nystromformer", "NystromformerModel"),
|
("nystromformer", "NystromformerModel"),
|
||||||
("oneformer", "OneFormerModel"),
|
("oneformer", "OneFormerModel"),
|
||||||
("openai-gpt", "OpenAIGPTModel"),
|
("openai-gpt", "OpenAIGPTModel"),
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("hubert", "Wav2Vec2Processor"),
|
("hubert", "Wav2Vec2Processor"),
|
||||||
("layoutlmv2", "LayoutLMv2Processor"),
|
("layoutlmv2", "LayoutLMv2Processor"),
|
||||||
("layoutlmv3", "LayoutLMv3Processor"),
|
("layoutlmv3", "LayoutLMv3Processor"),
|
||||||
("layoutxlm", "LayoutXLMProcessor"),
|
|
||||||
("markuplm", "MarkupLMProcessor"),
|
("markuplm", "MarkupLMProcessor"),
|
||||||
("oneformer", "OneFormerProcessor"),
|
("oneformer", "OneFormerProcessor"),
|
||||||
("owlvit", "OwlViTProcessor"),
|
("owlvit", "OwlViTProcessor"),
|
||||||
@@ -72,7 +71,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
|
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
|
||||||
("wav2vec2", "Wav2Vec2Processor"),
|
("wav2vec2", "Wav2Vec2Processor"),
|
||||||
("wav2vec2-conformer", "Wav2Vec2Processor"),
|
("wav2vec2-conformer", "Wav2Vec2Processor"),
|
||||||
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
|
|
||||||
("wavlm", "Wav2Vec2Processor"),
|
("wavlm", "Wav2Vec2Processor"),
|
||||||
("whisper", "WhisperProcessor"),
|
("whisper", "WhisperProcessor"),
|
||||||
("xclip", "XCLIPProcessor"),
|
("xclip", "XCLIPProcessor"),
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from transformers import is_flax_available, is_tf_available, is_torch_available
|
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
|
||||||
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
|
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
|
||||||
from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
|
from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
|
||||||
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
|
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
|
||||||
@@ -646,6 +647,31 @@ def check_all_auto_object_names_being_defined():
|
|||||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||||
|
|
||||||
|
|
||||||
|
def check_all_auto_mapping_names_in_config_mapping_names():
|
||||||
|
"""Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
|
||||||
|
failures = []
|
||||||
|
|
||||||
|
# `TOKENIZER_PROCESSOR_MAPPING_NAMES` and `AutoTokenizer` is special, and don't need to follow the rule.
|
||||||
|
mapping_to_check = {
|
||||||
|
"IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
|
||||||
|
"FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
|
||||||
|
"PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
|
||||||
|
"MODEL_MAPPING_NAMES": MODEL_MAPPING_NAMES,
|
||||||
|
"TF_MODEL_MAPPING_NAMES": TF_MODEL_MAPPING_NAMES,
|
||||||
|
"FLAX_MODEL_MAPPING_NAMES": FLAX_MODEL_MAPPING_NAMES,
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, mapping in mapping_to_check.items():
|
||||||
|
for model_type, class_names in mapping.items():
|
||||||
|
if model_type not in CONFIG_MAPPING_NAMES:
|
||||||
|
failures.append(
|
||||||
|
f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
|
||||||
|
"`CONFIG_MAPPING_NAMES`."
|
||||||
|
)
|
||||||
|
if len(failures) > 0:
|
||||||
|
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||||
|
|
||||||
|
|
||||||
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
||||||
|
|
||||||
|
|
||||||
@@ -922,6 +948,8 @@ def check_repo_quality():
|
|||||||
check_all_models_are_auto_configured()
|
check_all_models_are_auto_configured()
|
||||||
print("Checking all names in auto name mappings are defined.")
|
print("Checking all names in auto name mappings are defined.")
|
||||||
check_all_auto_object_names_being_defined()
|
check_all_auto_object_names_being_defined()
|
||||||
|
print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
|
||||||
|
check_all_auto_mapping_names_in_config_mapping_names()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user