Make public versions of private tensor utils (#19775)

* Make public versions of private utils

* I need sleep
This commit is contained in:
Sylvain Gugger
2022-10-21 09:34:01 -04:00
committed by GitHub
parent 3aaabaa214
commit 9151e649a5
8 changed files with 80 additions and 49 deletions

View File

@@ -45,16 +45,20 @@ from .utils import (
download_url,
extract_commit_hash,
is_flax_available,
is_jax_tensor,
is_numpy_array,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tf_tensor,
is_tokenizers_available,
is_torch_available,
is_torch_device,
is_torch_tensor,
logging,
to_py_obj,
torch_required,
)
from .utils.generic import _is_jax, _is_numpy, _is_tensorflow, _is_torch, _is_torch_device
if TYPE_CHECKING:
@@ -696,15 +700,10 @@ class BatchEncoding(UserDict):
import jax.numpy as jnp # noqa: F811
as_tensor = jnp.array
is_tensor = _is_jax
is_tensor = is_jax_tensor
else:
as_tensor = np.asarray
is_tensor = _is_numpy
# (mfuntowicz: This code is unreachable)
# else:
# raise ImportError(
# f"Unable to convert output to tensors format {tensor_type}"
# )
is_tensor = is_numpy_array
# Do the tensor conversion in batch
for key, value in self.items():
@@ -753,7 +752,7 @@ class BatchEncoding(UserDict):
# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items()}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
@@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
break
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element):
if is_tf_tensor(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element):
elif is_torch_tensor(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors