From 9151e649a53fc8a7e5d8beec1ae8d27db1094aa7 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 21 Oct 2022 09:34:01 -0400 Subject: [PATCH] Make public versions of private tensor utils (#19775) * Make public versions of private utils * I need sleep --- .../feature_extraction_sequence_utils.py | 7 ++- src/transformers/feature_extraction_utils.py | 10 ++-- src/transformers/image_utils.py | 23 ++++----- .../models/luke/tokenization_luke.py | 8 ++-- .../models/mluke/tokenization_mluke.py | 8 ++-- src/transformers/tokenization_utils_base.py | 21 ++++----- src/transformers/utils/__init__.py | 5 ++ src/transformers/utils/generic.py | 47 ++++++++++++++++--- 8 files changed, 80 insertions(+), 49 deletions(-) diff --git a/src/transformers/feature_extraction_sequence_utils.py b/src/transformers/feature_extraction_sequence_utils.py index 0415686803..1b869e4d6b 100644 --- a/src/transformers/feature_extraction_sequence_utils.py +++ b/src/transformers/feature_extraction_sequence_utils.py @@ -20,8 +20,7 @@ from typing import Dict, List, Optional, Union import numpy as np from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin -from .utils import PaddingStrategy, TensorType, is_tf_available, is_torch_available, logging, to_numpy -from .utils.generic import _is_tensorflow, _is_torch +from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy logger = logging.get_logger(__name__) @@ -160,9 +159,9 @@ class SequenceFeatureExtractor(FeatureExtractionMixin): first_element = required_input[index][0] if return_tensors is None: - if is_tf_available() and _is_tensorflow(first_element): + if is_tf_tensor(first_element): return_tensors = "tf" - elif is_torch_available() and _is_torch(first_element): + elif is_torch_tensor(first_element): return_tensors = "pt" elif isinstance(first_element, (int, float, list, tuple, np.ndarray)): return_tensors = "np" diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 85c751b841..41abfa2a27 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -33,14 +33,16 @@ from .utils import ( copy_func, download_url, is_flax_available, + is_jax_tensor, + is_numpy_array, is_offline_mode, is_remote_url, is_tf_available, is_torch_available, + is_torch_device, logging, torch_required, ) -from .utils.generic import _is_jax, _is_numpy, _is_torch_device if TYPE_CHECKING: @@ -150,10 +152,10 @@ class BatchFeature(UserDict): import jax.numpy as jnp # noqa: F811 as_tensor = jnp.array - is_tensor = _is_jax + is_tensor = is_jax_tensor else: as_tensor = np.asarray - is_tensor = _is_numpy + is_tensor = is_numpy_array # Do the tensor conversion in batch for key, value in self.items(): @@ -188,7 +190,7 @@ class BatchFeature(UserDict): # This check catches things like APEX blindly calling "to" on all inputs to a module # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor - if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): + if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): self.data = {k: v.to(device=device) for k, v in self.data.items()} else: logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.") diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index e9af6ac082..42c67a5138 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -21,14 +21,21 @@ from packaging import version import requests -from .utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available +from .utils import ( + ExplicitEnum, + is_jax_tensor, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + is_vision_available, + to_numpy, +) from .utils.constants import ( # noqa: F401 IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ) -from .utils.generic import ExplicitEnum, _is_jax, _is_tensorflow, _is_torch, to_numpy if is_vision_available(): @@ -55,18 +62,6 @@ class ChannelDimension(ExplicitEnum): LAST = "channels_last" -def is_torch_tensor(obj): - return _is_torch(obj) if is_torch_available() else False - - -def is_tf_tensor(obj): - return _is_tensorflow(obj) if is_tf_available() else False - - -def is_jax_tensor(obj): - return _is_jax(obj) if is_flax_available() else False - - def is_valid_image(img): return ( isinstance(img, (PIL.Image.Image, np.ndarray)) diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index 3cbc9218c0..98931ddb6f 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -33,11 +33,9 @@ from ...tokenization_utils_base import ( TextInput, TextInputPair, TruncationStrategy, - _is_tensorflow, - _is_torch, to_py_obj, ) -from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging +from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging logger = logging.get_logger(__name__) @@ -1174,9 +1172,9 @@ class LukeTokenizer(RobertaTokenizer): first_element = required_input[index][0] # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. if not isinstance(first_element, (int, list, tuple)): - if is_tf_available() and _is_tensorflow(first_element): + if is_tf_tensor(first_element): return_tensors = "tf" if return_tensors is None else return_tensors - elif is_torch_available() and _is_torch(first_element): + elif is_torch_tensor(first_element): return_tensors = "pt" if return_tensors is None else return_tensors elif isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py index 57272c391f..cafb0aee01 100644 --- a/src/transformers/models/mluke/tokenization_mluke.py +++ b/src/transformers/models/mluke/tokenization_mluke.py @@ -37,11 +37,9 @@ from ...tokenization_utils_base import ( TextInput, TextInputPair, TruncationStrategy, - _is_tensorflow, - _is_torch, to_py_obj, ) -from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging +from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging logger = logging.get_logger(__name__) @@ -1287,9 +1285,9 @@ class MLukeTokenizer(PreTrainedTokenizer): first_element = required_input[index][0] # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. if not isinstance(first_element, (int, list, tuple)): - if is_tf_available() and _is_tensorflow(first_element): + if is_tf_tensor(first_element): return_tensors = "tf" if return_tensors is None else return_tensors - elif is_torch_available() and _is_torch(first_element): + elif is_torch_tensor(first_element): return_tensors = "pt" if return_tensors is None else return_tensors elif isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index b37ed59ef3..cacb5285ef 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -45,16 +45,20 @@ from .utils import ( download_url, extract_commit_hash, is_flax_available, + is_jax_tensor, + is_numpy_array, is_offline_mode, is_remote_url, is_tf_available, + is_tf_tensor, is_tokenizers_available, is_torch_available, + is_torch_device, + is_torch_tensor, logging, to_py_obj, torch_required, ) -from .utils.generic import _is_jax, _is_numpy, _is_tensorflow, _is_torch, _is_torch_device if TYPE_CHECKING: @@ -696,15 +700,10 @@ class BatchEncoding(UserDict): import jax.numpy as jnp # noqa: F811 as_tensor = jnp.array - is_tensor = _is_jax + is_tensor = is_jax_tensor else: as_tensor = np.asarray - is_tensor = _is_numpy - # (mfuntowicz: This code is unreachable) - # else: - # raise ImportError( - # f"Unable to convert output to tensors format {tensor_type}" - # ) + is_tensor = is_numpy_array # Do the tensor conversion in batch for key, value in self.items(): @@ -753,7 +752,7 @@ class BatchEncoding(UserDict): # This check catches things like APEX blindly calling "to" on all inputs to a module # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor - if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): + if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): self.data = {k: v.to(device=device) for k, v in self.data.items()} else: logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") @@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): break # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. if not isinstance(first_element, (int, list, tuple)): - if is_tf_available() and _is_tensorflow(first_element): + if is_tf_tensor(first_element): return_tensors = "tf" if return_tensors is None else return_tensors - elif is_torch_available() and _is_torch(first_element): + elif is_torch_tensor(first_element): return_tensors = "pt" if return_tensors is None else return_tensors elif isinstance(first_element, np.ndarray): return_tensors = "np" if return_tensors is None else return_tensors diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 97e013fee5..7857339379 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -40,7 +40,12 @@ from .generic import ( cached_property, find_labels, flatten_dict, + is_jax_tensor, + is_numpy_array, is_tensor, + is_tf_tensor, + is_torch_device, + is_torch_tensor, to_numpy, to_py_obj, working_or_temp_dir, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 9e8ae759d9..a53f769f05 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -83,30 +83,65 @@ def _is_numpy(x): return isinstance(x, np.ndarray) +def is_numpy_array(x): + """ + Tests if `x` is a numpy array or not. + """ + return _is_numpy(x) + + def _is_torch(x): import torch return isinstance(x, torch.Tensor) +def is_torch_tensor(x): + """ + Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed. + """ + return False if not is_torch_available() else _is_torch(x) + + def _is_torch_device(x): import torch return isinstance(x, torch.device) +def is_torch_device(x): + """ + Tests if `x` is a torch device or not. Safe to call even if torch is not installed. + """ + return False if not is_torch_available() else _is_torch_device(x) + + def _is_tensorflow(x): import tensorflow as tf return isinstance(x, tf.Tensor) +def is_tf_tensor(x): + """ + Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed. + """ + return False if not is_tf_available() else _is_tensorflow(x) + + def _is_jax(x): import jax.numpy as jnp # noqa: F811 return isinstance(x, jnp.ndarray) +def is_jax_tensor(x): + """ + Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed. + """ + return False if not is_flax_available() else _is_jax(x) + + def to_py_obj(obj): """ Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list. @@ -115,11 +150,11 @@ def to_py_obj(obj): return {k: to_py_obj(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return [to_py_obj(o) for o in obj] - elif is_tf_available() and _is_tensorflow(obj): + elif is_tf_tensor(obj): return obj.numpy().tolist() - elif is_torch_available() and _is_torch(obj): + elif is_torch_tensor(obj): return obj.detach().cpu().tolist() - elif is_flax_available() and _is_jax(obj): + elif is_jax_tensor(obj): return np.asarray(obj).tolist() elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays return obj.tolist() @@ -135,11 +170,11 @@ def to_numpy(obj): return {k: to_numpy(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return np.array(obj) - elif is_tf_available() and _is_tensorflow(obj): + elif is_tf_tensor(obj): return obj.numpy() - elif is_torch_available() and _is_torch(obj): + elif is_torch_tensor(obj): return obj.detach().cpu().numpy() - elif is_flax_available() and _is_jax(obj): + elif is_jax_tensor(obj): return np.asarray(obj) else: return obj