Make public versions of private tensor utils (#19775)
* Make public versions of private utils * I need sleep
This commit is contained in:
@@ -20,8 +20,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||||
from .utils import PaddingStrategy, TensorType, is_tf_available, is_torch_available, logging, to_numpy
|
from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
|
||||||
from .utils.generic import _is_tensorflow, _is_torch
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -160,9 +159,9 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
|
|||||||
first_element = required_input[index][0]
|
first_element = required_input[index][0]
|
||||||
|
|
||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
if is_tf_available() and _is_tensorflow(first_element):
|
if is_tf_tensor(first_element):
|
||||||
return_tensors = "tf"
|
return_tensors = "tf"
|
||||||
elif is_torch_available() and _is_torch(first_element):
|
elif is_torch_tensor(first_element):
|
||||||
return_tensors = "pt"
|
return_tensors = "pt"
|
||||||
elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
|
elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
|
||||||
return_tensors = "np"
|
return_tensors = "np"
|
||||||
|
|||||||
@@ -33,14 +33,16 @@ from .utils import (
|
|||||||
copy_func,
|
copy_func,
|
||||||
download_url,
|
download_url,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
|
is_jax_tensor,
|
||||||
|
is_numpy_array,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
|
is_torch_device,
|
||||||
logging,
|
logging,
|
||||||
torch_required,
|
torch_required,
|
||||||
)
|
)
|
||||||
from .utils.generic import _is_jax, _is_numpy, _is_torch_device
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -150,10 +152,10 @@ class BatchFeature(UserDict):
|
|||||||
import jax.numpy as jnp # noqa: F811
|
import jax.numpy as jnp # noqa: F811
|
||||||
|
|
||||||
as_tensor = jnp.array
|
as_tensor = jnp.array
|
||||||
is_tensor = _is_jax
|
is_tensor = is_jax_tensor
|
||||||
else:
|
else:
|
||||||
as_tensor = np.asarray
|
as_tensor = np.asarray
|
||||||
is_tensor = _is_numpy
|
is_tensor = is_numpy_array
|
||||||
|
|
||||||
# Do the tensor conversion in batch
|
# Do the tensor conversion in batch
|
||||||
for key, value in self.items():
|
for key, value in self.items():
|
||||||
@@ -188,7 +190,7 @@ class BatchFeature(UserDict):
|
|||||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||||
# into a HalfTensor
|
# into a HalfTensor
|
||||||
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
|
logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
|
||||||
|
|||||||
@@ -21,14 +21,21 @@ from packaging import version
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from .utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
|
from .utils import (
|
||||||
|
ExplicitEnum,
|
||||||
|
is_jax_tensor,
|
||||||
|
is_tf_tensor,
|
||||||
|
is_torch_available,
|
||||||
|
is_torch_tensor,
|
||||||
|
is_vision_available,
|
||||||
|
to_numpy,
|
||||||
|
)
|
||||||
from .utils.constants import ( # noqa: F401
|
from .utils.constants import ( # noqa: F401
|
||||||
IMAGENET_DEFAULT_MEAN,
|
IMAGENET_DEFAULT_MEAN,
|
||||||
IMAGENET_DEFAULT_STD,
|
IMAGENET_DEFAULT_STD,
|
||||||
IMAGENET_STANDARD_MEAN,
|
IMAGENET_STANDARD_MEAN,
|
||||||
IMAGENET_STANDARD_STD,
|
IMAGENET_STANDARD_STD,
|
||||||
)
|
)
|
||||||
from .utils.generic import ExplicitEnum, _is_jax, _is_tensorflow, _is_torch, to_numpy
|
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -55,18 +62,6 @@ class ChannelDimension(ExplicitEnum):
|
|||||||
LAST = "channels_last"
|
LAST = "channels_last"
|
||||||
|
|
||||||
|
|
||||||
def is_torch_tensor(obj):
    """
    Tests if `obj` is a torch tensor.

    Returns `False` whenever `is_torch_available()` is falsy; otherwise the answer comes from `_is_torch(obj)`.
    """
    if not is_torch_available():
        return False
    return _is_torch(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def is_tf_tensor(obj):
    """
    Tests if `obj` is a TensorFlow tensor.

    Returns `False` whenever `is_tf_available()` is falsy; otherwise the answer comes from `_is_tensorflow(obj)`.
    """
    if not is_tf_available():
        return False
    return _is_tensorflow(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def is_jax_tensor(obj):
    """
    Tests if `obj` is a Jax tensor.

    Availability is decided by `is_flax_available()`: when it is falsy the function returns `False`; otherwise the
    answer comes from `_is_jax(obj)`.
    """
    if not is_flax_available():
        return False
    return _is_jax(obj)
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_image(img):
|
def is_valid_image(img):
|
||||||
return (
|
return (
|
||||||
isinstance(img, (PIL.Image.Image, np.ndarray))
|
isinstance(img, (PIL.Image.Image, np.ndarray))
|
||||||
|
|||||||
@@ -33,11 +33,9 @@ from ...tokenization_utils_base import (
|
|||||||
TextInput,
|
TextInput,
|
||||||
TextInputPair,
|
TextInputPair,
|
||||||
TruncationStrategy,
|
TruncationStrategy,
|
||||||
_is_tensorflow,
|
|
||||||
_is_torch,
|
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
)
|
)
|
||||||
from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging
|
from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -1174,9 +1172,9 @@ class LukeTokenizer(RobertaTokenizer):
|
|||||||
first_element = required_input[index][0]
|
first_element = required_input[index][0]
|
||||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||||
if not isinstance(first_element, (int, list, tuple)):
|
if not isinstance(first_element, (int, list, tuple)):
|
||||||
if is_tf_available() and _is_tensorflow(first_element):
|
if is_tf_tensor(first_element):
|
||||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||||
elif is_torch_available() and _is_torch(first_element):
|
elif is_torch_tensor(first_element):
|
||||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||||
elif isinstance(first_element, np.ndarray):
|
elif isinstance(first_element, np.ndarray):
|
||||||
return_tensors = "np" if return_tensors is None else return_tensors
|
return_tensors = "np" if return_tensors is None else return_tensors
|
||||||
|
|||||||
@@ -37,11 +37,9 @@ from ...tokenization_utils_base import (
|
|||||||
TextInput,
|
TextInput,
|
||||||
TextInputPair,
|
TextInputPair,
|
||||||
TruncationStrategy,
|
TruncationStrategy,
|
||||||
_is_tensorflow,
|
|
||||||
_is_torch,
|
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
)
|
)
|
||||||
from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging
|
from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -1287,9 +1285,9 @@ class MLukeTokenizer(PreTrainedTokenizer):
|
|||||||
first_element = required_input[index][0]
|
first_element = required_input[index][0]
|
||||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||||
if not isinstance(first_element, (int, list, tuple)):
|
if not isinstance(first_element, (int, list, tuple)):
|
||||||
if is_tf_available() and _is_tensorflow(first_element):
|
if is_tf_tensor(first_element):
|
||||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||||
elif is_torch_available() and _is_torch(first_element):
|
elif is_torch_tensor(first_element):
|
||||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||||
elif isinstance(first_element, np.ndarray):
|
elif isinstance(first_element, np.ndarray):
|
||||||
return_tensors = "np" if return_tensors is None else return_tensors
|
return_tensors = "np" if return_tensors is None else return_tensors
|
||||||
|
|||||||
@@ -45,16 +45,20 @@ from .utils import (
|
|||||||
download_url,
|
download_url,
|
||||||
extract_commit_hash,
|
extract_commit_hash,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
|
is_jax_tensor,
|
||||||
|
is_numpy_array,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
|
is_tf_tensor,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
|
is_torch_device,
|
||||||
|
is_torch_tensor,
|
||||||
logging,
|
logging,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
torch_required,
|
torch_required,
|
||||||
)
|
)
|
||||||
from .utils.generic import _is_jax, _is_numpy, _is_tensorflow, _is_torch, _is_torch_device
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -696,15 +700,10 @@ class BatchEncoding(UserDict):
|
|||||||
import jax.numpy as jnp # noqa: F811
|
import jax.numpy as jnp # noqa: F811
|
||||||
|
|
||||||
as_tensor = jnp.array
|
as_tensor = jnp.array
|
||||||
is_tensor = _is_jax
|
is_tensor = is_jax_tensor
|
||||||
else:
|
else:
|
||||||
as_tensor = np.asarray
|
as_tensor = np.asarray
|
||||||
is_tensor = _is_numpy
|
is_tensor = is_numpy_array
|
||||||
# (mfuntowicz: This code is unreachable)
|
|
||||||
# else:
|
|
||||||
# raise ImportError(
|
|
||||||
# f"Unable to convert output to tensors format {tensor_type}"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# Do the tensor conversion in batch
|
# Do the tensor conversion in batch
|
||||||
for key, value in self.items():
|
for key, value in self.items():
|
||||||
@@ -753,7 +752,7 @@ class BatchEncoding(UserDict):
|
|||||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||||
# into a HalfTensor
|
# into a HalfTensor
|
||||||
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
||||||
@@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
break
|
break
|
||||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||||
if not isinstance(first_element, (int, list, tuple)):
|
if not isinstance(first_element, (int, list, tuple)):
|
||||||
if is_tf_available() and _is_tensorflow(first_element):
|
if is_tf_tensor(first_element):
|
||||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||||
elif is_torch_available() and _is_torch(first_element):
|
elif is_torch_tensor(first_element):
|
||||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||||
elif isinstance(first_element, np.ndarray):
|
elif isinstance(first_element, np.ndarray):
|
||||||
return_tensors = "np" if return_tensors is None else return_tensors
|
return_tensors = "np" if return_tensors is None else return_tensors
|
||||||
|
|||||||
@@ -40,7 +40,12 @@ from .generic import (
|
|||||||
cached_property,
|
cached_property,
|
||||||
find_labels,
|
find_labels,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
|
is_jax_tensor,
|
||||||
|
is_numpy_array,
|
||||||
is_tensor,
|
is_tensor,
|
||||||
|
is_tf_tensor,
|
||||||
|
is_torch_device,
|
||||||
|
is_torch_tensor,
|
||||||
to_numpy,
|
to_numpy,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
working_or_temp_dir,
|
working_or_temp_dir,
|
||||||
|
|||||||
@@ -83,30 +83,65 @@ def _is_numpy(x):
|
|||||||
return isinstance(x, np.ndarray)
|
return isinstance(x, np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
def is_numpy_array(x):
    """
    Tests if `x` is a numpy array or not.

    Public entry point that forwards to the private `_is_numpy` helper and returns its result unchanged.
    """
    return _is_numpy(x)
|
||||||
|
|
||||||
|
|
||||||
def _is_torch(x):
|
def _is_torch(x):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
return isinstance(x, torch.Tensor)
|
return isinstance(x, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_tensor(x):
    """
    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.

    When `is_torch_available()` is falsy the function returns `False` without calling `_is_torch`; otherwise it
    returns `_is_torch(x)`.
    """
    if not is_torch_available():
        return False
    return _is_torch(x)
|
||||||
|
|
||||||
|
|
||||||
def _is_torch_device(x):
|
def _is_torch_device(x):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
return isinstance(x, torch.device)
|
return isinstance(x, torch.device)
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_device(x):
    """
    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.

    When `is_torch_available()` is falsy the function returns `False` without calling `_is_torch_device`; otherwise
    it returns `_is_torch_device(x)`.
    """
    if not is_torch_available():
        return False
    return _is_torch_device(x)
|
||||||
|
|
||||||
|
|
||||||
def _is_tensorflow(x):
    """
    Tests if `x` is a `tf.Tensor`.

    tensorflow is imported inside the function body, so calling this requires tensorflow to be importable.
    """
    import tensorflow as tf

    return isinstance(x, tf.Tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def is_tf_tensor(x):
    """
    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.

    When `is_tf_available()` is falsy the function returns `False` without calling `_is_tensorflow`; otherwise it
    returns `_is_tensorflow(x)`.
    """
    if not is_tf_available():
        return False
    return _is_tensorflow(x)
|
||||||
|
|
||||||
|
|
||||||
def _is_jax(x):
|
def _is_jax(x):
|
||||||
import jax.numpy as jnp # noqa: F811
|
import jax.numpy as jnp # noqa: F811
|
||||||
|
|
||||||
return isinstance(x, jnp.ndarray)
|
return isinstance(x, jnp.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
def is_jax_tensor(x):
    """
    Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.

    Availability is decided by `is_flax_available()`: when it is falsy the function returns `False` without calling
    `_is_jax`; otherwise it returns `_is_jax(x)`.
    """
    if not is_flax_available():
        return False
    return _is_jax(x)
|
||||||
|
|
||||||
|
|
||||||
def to_py_obj(obj):
|
def to_py_obj(obj):
|
||||||
"""
|
"""
|
||||||
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
|
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
|
||||||
@@ -115,11 +150,11 @@ def to_py_obj(obj):
|
|||||||
return {k: to_py_obj(v) for k, v in obj.items()}
|
return {k: to_py_obj(v) for k, v in obj.items()}
|
||||||
elif isinstance(obj, (list, tuple)):
|
elif isinstance(obj, (list, tuple)):
|
||||||
return [to_py_obj(o) for o in obj]
|
return [to_py_obj(o) for o in obj]
|
||||||
elif is_tf_available() and _is_tensorflow(obj):
|
elif is_tf_tensor(obj):
|
||||||
return obj.numpy().tolist()
|
return obj.numpy().tolist()
|
||||||
elif is_torch_available() and _is_torch(obj):
|
elif is_torch_tensor(obj):
|
||||||
return obj.detach().cpu().tolist()
|
return obj.detach().cpu().tolist()
|
||||||
elif is_flax_available() and _is_jax(obj):
|
elif is_jax_tensor(obj):
|
||||||
return np.asarray(obj).tolist()
|
return np.asarray(obj).tolist()
|
||||||
elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays
|
elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays
|
||||||
return obj.tolist()
|
return obj.tolist()
|
||||||
@@ -135,11 +170,11 @@ def to_numpy(obj):
|
|||||||
return {k: to_numpy(v) for k, v in obj.items()}
|
return {k: to_numpy(v) for k, v in obj.items()}
|
||||||
elif isinstance(obj, (list, tuple)):
|
elif isinstance(obj, (list, tuple)):
|
||||||
return np.array(obj)
|
return np.array(obj)
|
||||||
elif is_tf_available() and _is_tensorflow(obj):
|
elif is_tf_tensor(obj):
|
||||||
return obj.numpy()
|
return obj.numpy()
|
||||||
elif is_torch_available() and _is_torch(obj):
|
elif is_torch_tensor(obj):
|
||||||
return obj.detach().cpu().numpy()
|
return obj.detach().cpu().numpy()
|
||||||
elif is_flax_available() and _is_jax(obj):
|
elif is_jax_tensor(obj):
|
||||||
return np.asarray(obj)
|
return np.asarray(obj)
|
||||||
else:
|
else:
|
||||||
return obj
|
return obj
|
||||||
|
|||||||
Reference in New Issue
Block a user