Fix dynamic module import error (#21646)
* fix dynamic module import error --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,9 @@ import importlib
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
@@ -143,9 +145,25 @@ def get_class_in_module(class_name, module_path):
|
|||||||
"""
|
"""
|
||||||
Import a module on the cache directory for modules and extract a class from it.
|
Import a module on the cache directory for modules and extract a class from it.
|
||||||
"""
|
"""
|
||||||
module_path = module_path.replace(os.path.sep, ".")
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
module = importlib.import_module(module_path)
|
module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
|
||||||
return getattr(module, class_name)
|
module_file_name = module_path.split(os.path.sep)[-1] + ".py"
|
||||||
|
|
||||||
|
# Copy to a temporary directory. We need to do this in another process to avoid a strange and flaky error
|
||||||
|
# `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'`
|
||||||
|
shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir)
|
||||||
|
# On Windows, we need this character `r` before the path argument of `os.remove`
|
||||||
|
cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")'
|
||||||
|
subprocess.run(["python", "-c", cmd])
|
||||||
|
|
||||||
|
# copy back the file that we want to import
|
||||||
|
shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}")
|
||||||
|
|
||||||
|
# import the module
|
||||||
|
module_path = module_path.replace(os.path.sep, ".")
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
|
||||||
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
def get_cached_module_file(
|
def get_cached_module_file(
|
||||||
@@ -212,7 +230,7 @@ def get_cached_module_file(
|
|||||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
|
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
submodule = "local"
|
submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
|
||||||
else:
|
else:
|
||||||
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
||||||
|
|
||||||
@@ -240,7 +258,7 @@ def get_cached_module_file(
|
|||||||
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
||||||
create_dynamic_module(full_submodule)
|
create_dynamic_module(full_submodule)
|
||||||
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
||||||
if submodule == "local":
|
if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
|
||||||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
||||||
# that hash, to only copy when there is a modification but it seems overkill for now).
|
# that hash, to only copy when there is a modification but it seems overkill for now).
|
||||||
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
||||||
|
|||||||
Reference in New Issue
Block a user