Fast imports part 3 (#9474)
* New intermediate inits * Update template * Avoid importing torch/tf/flax in tokenization unless necessary * Styling * Shutup flake8 * Better python version check
This commit is contained in:
@@ -16,11 +16,43 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...file_utils import is_torch_available
|
||||
from .configuration_rag import RagConfig
|
||||
from .retrieval_rag import RagRetriever
|
||||
from .tokenization_rag import RagTokenizer
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_rag": ["RagConfig"],
|
||||
"retrieval_rag": ["RagRetriever"],
|
||||
"tokenization_rag": ["RagTokenizer"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
|
||||
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_rag import RagConfig
|
||||
from .retrieval_rag import RagRetriever
|
||||
from .tokenization_rag import RagTokenizer
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
class _LazyModule(_BaseLazyModule):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
__file__ = globals()["__file__"]
|
||||
__path__ = [os.path.dirname(__file__)]
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
|
||||
|
||||
Reference in New Issue
Block a user