Fast imports part 3 (#9474)

* New intermediate inits

* Update template

* Avoid importing torch/tf/flax in tokenization unless necessary

* Styling

* Shutup flake8

* Better python version check
This commit is contained in:
Sylvain Gugger
2021-01-08 07:40:59 -05:00
committed by GitHub
parent 79bbcc5260
commit 1bdf42409c
50 changed files with 3205 additions and 828 deletions

View File

@@ -25,7 +25,7 @@ import warnings
from collections import OrderedDict, UserDict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
@@ -45,21 +45,34 @@ from .file_utils import (
from .utils import logging
if is_tf_available():
import tensorflow as tf
if is_torch_available():
import torch
if is_flax_available():
import jax.numpy as jnp
if TYPE_CHECKING:
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp # noqa: F401
def _is_numpy(x):
return isinstance(x, np.ndarray)
def _is_torch(x):
import torch
return isinstance(x, torch.Tensor)
def _is_tensorflow(x):
import tensorflow as tf
return isinstance(x, tf.Tensor)
def _is_jax(x):
import jax.numpy as jnp # noqa: F811
return isinstance(x, jnp.ndarray)
@@ -196,9 +209,9 @@ def to_py_obj(obj):
return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [to_py_obj(o) for o in obj]
elif is_tf_available() and isinstance(obj, tf.Tensor):
elif is_tf_available() and _is_tensorflow(obj):
return obj.numpy().tolist()
elif is_torch_available() and isinstance(obj, torch.Tensor):
elif is_torch_available() and _is_torch(obj):
return obj.detach().cpu().tolist()
elif isinstance(obj, np.ndarray):
return obj.tolist()
@@ -714,16 +727,22 @@ class BatchEncoding(UserDict):
raise ImportError(
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
)
import tensorflow as tf
as_tensor = tf.constant
is_tensor = tf.is_tensor
elif tensor_type == TensorType.PYTORCH:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
as_tensor = torch.tensor
is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
import jax.numpy as jnp # noqa: F811
as_tensor = jnp.array
is_tensor = _is_jax
else:
@@ -2684,9 +2703,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
first_element = encoded_inputs["input_ids"][index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)):
if is_tf_available() and isinstance(first_element, tf.Tensor):
if is_tf_available() and _is_tensorflow(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and isinstance(first_element, torch.Tensor):
elif is_torch_available() and _is_torch(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors