Fast imports part 3 (#9474)

* New intermediate inits

* Update template

* Avoid importing torch/tf/flax in tokenization unless necessary

* Styling

* Shutup flake8

* Better python version check
This commit is contained in:
Sylvain Gugger
2021-01-08 07:40:59 -05:00
committed by GitHub
parent 79bbcc5260
commit 1bdf42409c
50 changed files with 3205 additions and 828 deletions

View File

@@ -16,11 +16,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ...file_utils import is_torch_available
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever
from .tokenization_rag import RagTokenizer
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_torch_available
_import_structure = {
"configuration_rag": ["RagConfig"],
"retrieval_rag": ["RagRetriever"],
"tokenization_rag": ["RagTokenizer"],
}
if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
if TYPE_CHECKING:
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever
from .tokenization_rag import RagTokenizer
if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)