Make sure all submodules are properly registered (#15144)
* Make sure all submodules are properly registered * Try to fix tests * Fix tests
This commit is contained in:
@@ -14,8 +14,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
@@ -202,5 +204,58 @@ def check_all_inits():
|
||||
raise ValueError("\n\n".join(failures))
|
||||
|
||||
|
||||
def get_transformers_submodules():
|
||||
"""
|
||||
Returns the list of Transformers submodules.
|
||||
"""
|
||||
submodules = []
|
||||
for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
|
||||
for folder in directories:
|
||||
if folder.startswith("_"):
|
||||
directories.remove(folder)
|
||||
continue
|
||||
short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS))
|
||||
submodule = short_path.replace(os.path.sep, ".")
|
||||
submodules.append(submodule)
|
||||
for fname in files:
|
||||
if fname == "__init__.py":
|
||||
continue
|
||||
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
|
||||
submodule = short_path.replace(os.path.sep, ".").replace(".py", "")
|
||||
if len(submodule.split(".")) == 1:
|
||||
submodules.append(submodule)
|
||||
return submodules
|
||||
|
||||
|
||||
IGNORE_SUBMODULES = [
|
||||
"convert_pytorch_checkpoint_to_tf2",
|
||||
"modeling_flax_pytorch_utils",
|
||||
]
|
||||
|
||||
|
||||
def check_submodules():
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"transformers",
|
||||
os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
|
||||
submodule_search_locations=[PATH_TO_TRANSFORMERS],
|
||||
)
|
||||
transformers = spec.loader.load_module()
|
||||
|
||||
module_not_registered = [
|
||||
module
|
||||
for module in get_transformers_submodules()
|
||||
if module not in IGNORE_SUBMODULES and module not in transformers._import_structure.keys()
|
||||
]
|
||||
if len(module_not_registered) > 0:
|
||||
list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
|
||||
raise ValueError(
|
||||
"The following submodules are not properly registed in the main init of Transformers:\n"
|
||||
f"{list_of_modules}\n"
|
||||
"Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_all_inits()
|
||||
check_submodules()
|
||||
|
||||
Reference in New Issue
Block a user