Fast imports part 3 (#9474)
* New intermediate inits * Update template * Avoid importing torch/tf/flax in tokenization unless necessary * Styling * Shutup flake8 * Better python version check
This commit is contained in:
@@ -25,7 +25,7 @@ import warnings
|
||||
from collections import OrderedDict, UserDict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -45,21 +45,34 @@ from .file_utils import (
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp # noqa: F401
|
||||
|
||||
|
||||
def _is_numpy(x):
|
||||
return isinstance(x, np.ndarray)
|
||||
|
||||
|
||||
def _is_torch(x):
|
||||
import torch
|
||||
|
||||
return isinstance(x, torch.Tensor)
|
||||
|
||||
|
||||
def _is_tensorflow(x):
|
||||
import tensorflow as tf
|
||||
|
||||
return isinstance(x, tf.Tensor)
|
||||
|
||||
|
||||
def _is_jax(x):
|
||||
import jax.numpy as jnp # noqa: F811
|
||||
|
||||
return isinstance(x, jnp.ndarray)
|
||||
|
||||
|
||||
@@ -196,9 +209,9 @@ def to_py_obj(obj):
|
||||
return {k: to_py_obj(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [to_py_obj(o) for o in obj]
|
||||
elif is_tf_available() and isinstance(obj, tf.Tensor):
|
||||
elif is_tf_available() and _is_tensorflow(obj):
|
||||
return obj.numpy().tolist()
|
||||
elif is_torch_available() and isinstance(obj, torch.Tensor):
|
||||
elif is_torch_available() and _is_torch(obj):
|
||||
return obj.detach().cpu().tolist()
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
@@ -714,16 +727,22 @@ class BatchEncoding(UserDict):
|
||||
raise ImportError(
|
||||
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
|
||||
)
|
||||
import tensorflow as tf
|
||||
|
||||
as_tensor = tf.constant
|
||||
is_tensor = tf.is_tensor
|
||||
elif tensor_type == TensorType.PYTORCH:
|
||||
if not is_torch_available():
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
import torch
|
||||
|
||||
as_tensor = torch.tensor
|
||||
is_tensor = torch.is_tensor
|
||||
elif tensor_type == TensorType.JAX:
|
||||
if not is_flax_available():
|
||||
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
||||
import jax.numpy as jnp # noqa: F811
|
||||
|
||||
as_tensor = jnp.array
|
||||
is_tensor = _is_jax
|
||||
else:
|
||||
@@ -2684,9 +2703,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
first_element = encoded_inputs["input_ids"][index][0]
|
||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||
if not isinstance(first_element, (int, list, tuple)):
|
||||
if is_tf_available() and isinstance(first_element, tf.Tensor):
|
||||
if is_tf_available() and _is_tensorflow(first_element):
|
||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||
elif is_torch_available() and isinstance(first_element, torch.Tensor):
|
||||
elif is_torch_available() and _is_torch(first_element):
|
||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||
elif isinstance(first_element, np.ndarray):
|
||||
return_tensors = "np" if return_tensors is None else return_tensors
|
||||
|
||||
Reference in New Issue
Block a user