From c21c3737c1d64aa15a5853366818efc72694cdc9 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 17 Jul 2023 12:53:03 -0400 Subject: [PATCH] Add TAPEX to the list of deprecated models (#24859) * Add TAPEX to the list of deprecated models * Add check * Fix typo * Fix import path for Van conversion --- .../models/auto/configuration_auto.py | 1 + .../deprecated/van/convert_van_to_pytorch.py | 2 +- utils/check_repo.py | 28 +++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 893a319a02..0ae8e34e4a 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -646,6 +646,7 @@ DEPRECATED_MODELS = [ "mctct", "mmbt", "retribert", + "tapex", "trajectory_transformer", "van", ] diff --git a/src/transformers/models/deprecated/van/convert_van_to_pytorch.py b/src/transformers/models/deprecated/van/convert_van_to_pytorch.py index c9bbc0fc91..20492e42be 100644 --- a/src/transformers/models/deprecated/van/convert_van_to_pytorch.py +++ b/src/transformers/models/deprecated/van/convert_van_to_pytorch.py @@ -31,7 +31,7 @@ from huggingface_hub import cached_download, hf_hub_download from torch import Tensor from transformers import AutoImageProcessor, VanConfig, VanForImageClassification -from transformers.models.van.modeling_van import VanLayerScaling +from transformers.models.deprecated.van.modeling_van import VanLayerScaling from transformers.utils import logging diff --git a/utils/check_repo.py b/utils/check_repo.py index 5a3184b10a..64bd343abf 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -1076,6 +1076,32 @@ def check_docstrings_are_in_md(): ) +def check_deprecated_constant_is_up_to_date(): + deprecated_folder = os.path.join(PATH_TO_TRANSFORMERS, "models", "deprecated") + deprecated_models = [m for m in os.listdir(deprecated_folder) if not m.startswith("_")] + + constant_to_check = transformers.models.auto.configuration_auto.DEPRECATED_MODELS + message = [] + missing_models = sorted(set(deprecated_models) - set(constant_to_check)) + if len(missing_models) != 0: + missing_models = ", ".join(missing_models) + message.append( + "The following models are in the deprecated folder, make sur to add them to `DEPRECATED_MODELS` in " + f"`models/auto/configuration_auto.py`: {missing_models}." + ) + + extra_models = sorted(set(constant_to_check) - set(deprecated_models)) + if len(extra_models) != 0: + extra_models = ", ".join(extra_models) + message.append( + "The following models are in the `DEPRECATED_MODELS` constant but not in the deprecated folder. Either " + f"remove them from the constant or move to the deprecated folder: {extra_models}." + ) + + if len(message) > 0: + raise Exception("\n".join(message)) + + def check_repo_quality(): """Check all models are properly tested and documented.""" print("Checking all models are included.") @@ -1097,6 +1123,8 @@ def check_repo_quality(): check_all_auto_mappings_importable() print("Checking all objects are equally (across frameworks) in the main __init__.") check_objects_being_equally_in_main_init() + print("Checking the DEPRECATED_MODELS constant is up to date.") + check_deprecated_constant_is_up_to_date() if __name__ == "__main__":