Transformers fast import part 2 (#9446)

* Main init work

* Add version

* Change from absolute to relative imports

* Fix imports

* One more typo

* More typos

* Styling

* Make quality script pass

* Add necessary replace in template

* Fix typos

* Spaces are ignored in replace for some reason

* Forgot one models.

* Fixes for import

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>

* Add documentation

* Styling

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger
2021-01-07 09:36:14 -05:00
committed by GitHub
parent a400fe8931
commit 758ed3332b
56 changed files with 2426 additions and 1377 deletions

View File

@@ -33,6 +33,7 @@ from dataclasses import fields
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from types import ModuleType
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
@@ -41,7 +42,6 @@ import numpy as np
from packaging import version
from tqdm.auto import tqdm
import importlib_metadata
import requests
from filelock import FileLock
@@ -50,6 +50,13 @@ from .hf_api import HfFolder
from .utils import logging
# The package importlib_metadata is in a different place, depending on the python version.
if version.parse(sys.version) < version.parse("3.8"):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
@@ -130,7 +137,7 @@ except importlib_metadata.PackageNotFoundError:
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
_scatter_version = importlib_metadata.version("torch_scatterr")
_scatter_version = importlib_metadata.version("torch_scatter")
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
except importlib_metadata.PackageNotFoundError:
_scatter_available = False
@@ -1415,3 +1422,40 @@ class ModelOutput(OrderedDict):
Convert self to a tuple containing all the attributes/keys that are not ``None``.
"""
return tuple(self[k] for k in self.keys())
class _BaseLazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, import_structure):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
# Needed for autocompletion in an IDE
def __dir__(self):
return super().__dir__() + self.__all__
def __getattr__(self, name: str) -> Any:
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str) -> ModuleType:
raise NotImplementedError