Remove hack for dynamic modules and use Python functions instead (#22537)
This commit is contained in:
@@ -13,14 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Utilities to dynamically load objects from the Hub."""
|
"""Utilities to dynamically load objects from the Hub."""
|
||||||
|
import filecmp
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
@@ -45,6 +43,7 @@ def init_hf_modules():
|
|||||||
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
|
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
|
||||||
if not init_path.exists():
|
if not init_path.exists():
|
||||||
init_path.touch()
|
init_path.touch()
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_module(name: Union[str, os.PathLike]):
|
def create_dynamic_module(name: Union[str, os.PathLike]):
|
||||||
@@ -60,6 +59,7 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
|
|||||||
init_path = dynamic_module_path / "__init__.py"
|
init_path = dynamic_module_path / "__init__.py"
|
||||||
if not init_path.exists():
|
if not init_path.exists():
|
||||||
init_path.touch()
|
init_path.touch()
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
|
||||||
|
|
||||||
def get_relative_imports(module_file):
|
def get_relative_imports(module_file):
|
||||||
@@ -148,35 +148,9 @@ def get_class_in_module(class_name, module_path):
|
|||||||
"""
|
"""
|
||||||
Import a module from the cache directory for modules and extract a class from it.
|
Import a module from the cache directory for modules and extract a class from it.
|
||||||
"""
|
"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
module_path = module_path.replace(os.path.sep, ".")
|
||||||
module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
|
module = importlib.import_module(module_path)
|
||||||
module_file_name = module_path.split(os.path.sep)[-1] + ".py"
|
return getattr(module, class_name)
|
||||||
|
|
||||||
# Copy to a temporary directory. We need to do this in another process to avoid strange and flaky error
|
|
||||||
# `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'`
|
|
||||||
shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir)
|
|
||||||
# On Windows, we need this character `r` before the path argument of `os.remove`
|
|
||||||
cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")'
|
|
||||||
# We don't know which python binary file exists in an environment. For example, if `python3` exists but not
|
|
||||||
# `python`, the call `subprocess.run(["python", ...])` gives `FileNotFoundError` (about python binary). Notice
|
|
||||||
# that, if the file to be removed is not found, we also have `FileNotFoundError`, but it is not raised to the
|
|
||||||
# caller's process.
|
|
||||||
try:
|
|
||||||
subprocess.run(["python", "-c", cmd])
|
|
||||||
except FileNotFoundError:
|
|
||||||
try:
|
|
||||||
subprocess.run(["python3", "-c", cmd])
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# copy back the file that we want to import
|
|
||||||
shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}")
|
|
||||||
|
|
||||||
# import the module
|
|
||||||
module_path = module_path.replace(os.path.sep, ".")
|
|
||||||
module = importlib.import_module(module_path)
|
|
||||||
|
|
||||||
return getattr(module, class_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cached_module_file(
|
def get_cached_module_file(
|
||||||
@@ -273,13 +247,21 @@ def get_cached_module_file(
|
|||||||
create_dynamic_module(full_submodule)
|
create_dynamic_module(full_submodule)
|
||||||
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
||||||
if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
|
if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
|
||||||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
# We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
|
||||||
# that hash, to only copy when there is a modification but it seems overkill for now).
|
# has changed since last copy.
|
||||||
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
if not (submodule_path / module_file).exists() or not filecmp.cmp(
|
||||||
shutil.copy(resolved_module_file, submodule_path / module_file)
|
resolved_module_file, str(submodule_path / module_file)
|
||||||
|
):
|
||||||
|
shutil.copy(resolved_module_file, submodule_path / module_file)
|
||||||
|
importlib.invalidate_caches()
|
||||||
for module_needed in modules_needed:
|
for module_needed in modules_needed:
|
||||||
module_needed = f"{module_needed}.py"
|
module_needed = f"{module_needed}.py"
|
||||||
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
|
module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
|
||||||
|
if not (submodule_path / module_needed).exists() or not filecmp.cmp(
|
||||||
|
module_needed_file, str(submodule_path / module_needed)
|
||||||
|
):
|
||||||
|
shutil.copy(module_needed_file, submodule_path / module_needed)
|
||||||
|
importlib.invalidate_caches()
|
||||||
else:
|
else:
|
||||||
# Get the commit hash
|
# Get the commit hash
|
||||||
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
|
||||||
@@ -293,6 +275,7 @@ def get_cached_module_file(
|
|||||||
|
|
||||||
if not (submodule_path / module_file).exists():
|
if not (submodule_path / module_file).exists():
|
||||||
shutil.copy(resolved_module_file, submodule_path / module_file)
|
shutil.copy(resolved_module_file, submodule_path / module_file)
|
||||||
|
importlib.invalidate_caches()
|
||||||
# Make sure we also have every file with relative imports
|
# Make sure we also have every file with relative imports
|
||||||
for module_needed in modules_needed:
|
for module_needed in modules_needed:
|
||||||
if not (submodule_path / module_needed).exists():
|
if not (submodule_path / module_needed).exists():
|
||||||
|
|||||||
Reference in New Issue
Block a user