From 28fcf006076f28b37ba3879356811347577053db Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Tue, 4 Apr 2023 09:20:13 -0400
Subject: [PATCH] Remove hack for dynamic modules and use Python functions instead (#22537)

---
 src/transformers/dynamic_module_utils.py | 57 +++++++++---------------
 1 file changed, 20 insertions(+), 37 deletions(-)

diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py
index 26d877d804..046980ad13 100644
--- a/src/transformers/dynamic_module_utils.py
+++ b/src/transformers/dynamic_module_utils.py
@@ -13,14 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Utilities to dynamically load objects from the Hub."""
-
+import filecmp
 import importlib
 import os
 import re
 import shutil
-import subprocess
 import sys
-import tempfile
 from pathlib import Path
 from typing import Dict, Optional, Union
 
@@ -45,6 +43,7 @@ def init_hf_modules():
     init_path = Path(HF_MODULES_CACHE) / "__init__.py"
     if not init_path.exists():
         init_path.touch()
+        importlib.invalidate_caches()
 
 
 def create_dynamic_module(name: Union[str, os.PathLike]):
@@ -60,6 +59,7 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
     init_path = dynamic_module_path / "__init__.py"
     if not init_path.exists():
         init_path.touch()
+        importlib.invalidate_caches()
 
 
 def get_relative_imports(module_file):
@@ -148,35 +148,9 @@ def get_class_in_module(class_name, module_path):
     """
     Import a module on the cache directory for modules and extract a class from it.
     """
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
-        module_file_name = module_path.split(os.path.sep)[-1] + ".py"
-
-        # Copy to a temporary directory. We need to do this in another process to avoid strange and flaky error
-        # `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'`
-        shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir)
-        # On Windows, we need this character `r` before the path argument of `os.remove`
-        cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")'
-        # We don't know which python binary file exists in an environment. For example, if `python3` exists but not
-        # `python`, the call `subprocess.run(["python", ...])` gives `FileNotFoundError` (about python binary). Notice
-        # that, if the file to be removed is not found, we also have `FileNotFoundError`, but it is not raised to the
-        # caller's process.
-        try:
-            subprocess.run(["python", "-c", cmd])
-        except FileNotFoundError:
-            try:
-                subprocess.run(["python3", "-c", cmd])
-            except FileNotFoundError:
-                pass
-
-        # copy back the file that we want to import
-        shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}")
-
-        # import the module
-        module_path = module_path.replace(os.path.sep, ".")
-        module = importlib.import_module(module_path)
-
-    return getattr(module, class_name)
+    module_path = module_path.replace(os.path.sep, ".")
+    module = importlib.import_module(module_path)
+    return getattr(module, class_name)
 
 
 def get_cached_module_file(
@@ -273,13 +247,21 @@ def get_cached_module_file(
     create_dynamic_module(full_submodule)
     submodule_path = Path(HF_MODULES_CACHE) / full_submodule
     if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
-        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
-        # that hash, to only copy when there is a modification but it seems overkill for now).
-        # The only reason we do the copy is to avoid putting too many folders in sys.path.
-        shutil.copy(resolved_module_file, submodule_path / module_file)
+        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
+        # has changed since last copy.
+        if not (submodule_path / module_file).exists() or not filecmp.cmp(
+            resolved_module_file, str(submodule_path / module_file)
+        ):
+            shutil.copy(resolved_module_file, submodule_path / module_file)
+            importlib.invalidate_caches()
         for module_needed in modules_needed:
             module_needed = f"{module_needed}.py"
-            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
+            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
+                module_needed_file, str(submodule_path / module_needed)
+            ):
+                shutil.copy(module_needed_file, submodule_path / module_needed)
+                importlib.invalidate_caches()
     else:
         # Get the commit hash
         # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
@@ -293,6 +275,7 @@ def get_cached_module_file(
 
         if not (submodule_path / module_file).exists():
             shutil.copy(resolved_module_file, submodule_path / module_file)
+            importlib.invalidate_caches()
         # Make sure we also have every file with relative
         for module_needed in modules_needed:
             if not (submodule_path / module_needed).exists():