Transformers fast import part 2 (#9446)
* Main init work * Add version * Change from absolute to relative imports * Fix imports * One more typo * More typos * Styling * Make quality script pass * Add necessary replace in template * Fix typos * Spaces are ignored in replace for some reason * Forgot one models. * Fixes for import Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr> * Add documentation * Styling Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -33,6 +33,7 @@ from dataclasses import fields
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
@@ -41,7 +42,6 @@ import numpy as np
|
||||
from packaging import version
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import importlib_metadata
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
|
||||
@@ -50,6 +50,13 @@ from .hf_api import HfFolder
|
||||
from .utils import logging
|
||||
|
||||
|
||||
# The package importlib_metadata is in a different place, depending on the python version.
|
||||
if version.parse(sys.version) < version.parse("3.8"):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
|
||||
@@ -130,7 +137,7 @@ except importlib_metadata.PackageNotFoundError:
|
||||
|
||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||
try:
|
||||
_scatter_version = importlib_metadata.version("torch_scatterr")
|
||||
_scatter_version = importlib_metadata.version("torch_scatter")
|
||||
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scatter_available = False
|
||||
@@ -1415,3 +1422,40 @@ class ModelOutput(OrderedDict):
|
||||
Convert self to a tuple containing all the attributes/keys that are not ``None``.
|
||||
"""
|
||||
return tuple(self[k] for k in self.keys())
|
||||
|
||||
|
||||
class _BaseLazyModule(ModuleType):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
# Very heavily inspired by optuna.integration._IntegrationModule
|
||||
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||
def __init__(self, name, import_structure):
|
||||
super().__init__(name)
|
||||
self._modules = set(import_structure.keys())
|
||||
self._class_to_module = {}
|
||||
for key, values in import_structure.items():
|
||||
for value in values:
|
||||
self._class_to_module[value] = key
|
||||
# Needed for autocompletion in an IDE
|
||||
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
|
||||
|
||||
# Needed for autocompletion in an IDE
|
||||
def __dir__(self):
|
||||
return super().__dir__() + self.__all__
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in self._modules:
|
||||
value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys():
|
||||
module = self._get_module(self._class_to_module[name])
|
||||
value = getattr(module, name)
|
||||
else:
|
||||
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
|
||||
def _get_module(self, module_name: str) -> ModuleType:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user