Make public versions of private tensor utils (#19775)
* Make public versions of private utils * I need sleep
This commit is contained in:
@@ -45,16 +45,20 @@ from .utils import (
|
||||
download_url,
|
||||
extract_commit_hash,
|
||||
is_flax_available,
|
||||
is_jax_tensor,
|
||||
is_numpy_array,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_tf_available,
|
||||
is_tf_tensor,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_device,
|
||||
is_torch_tensor,
|
||||
logging,
|
||||
to_py_obj,
|
||||
torch_required,
|
||||
)
|
||||
from .utils.generic import _is_jax, _is_numpy, _is_tensorflow, _is_torch, _is_torch_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -696,15 +700,10 @@ class BatchEncoding(UserDict):
|
||||
import jax.numpy as jnp # noqa: F811
|
||||
|
||||
as_tensor = jnp.array
|
||||
is_tensor = _is_jax
|
||||
is_tensor = is_jax_tensor
|
||||
else:
|
||||
as_tensor = np.asarray
|
||||
is_tensor = _is_numpy
|
||||
# (mfuntowicz: This code is unreachable)
|
||||
# else:
|
||||
# raise ImportError(
|
||||
# f"Unable to convert output to tensors format {tensor_type}"
|
||||
# )
|
||||
is_tensor = is_numpy_array
|
||||
|
||||
# Do the tensor conversion in batch
|
||||
for key, value in self.items():
|
||||
@@ -753,7 +752,7 @@ class BatchEncoding(UserDict):
|
||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||
# into a HalfTensor
|
||||
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
||||
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
||||
else:
|
||||
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
||||
@@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
break
|
||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||
if not isinstance(first_element, (int, list, tuple)):
|
||||
if is_tf_available() and _is_tensorflow(first_element):
|
||||
if is_tf_tensor(first_element):
|
||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||
elif is_torch_available() and _is_torch(first_element):
|
||||
elif is_torch_tensor(first_element):
|
||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||
elif isinstance(first_element, np.ndarray):
|
||||
return_tensors = "np" if return_tensors is None else return_tensors
|
||||
|
||||
Reference in New Issue
Block a user