[tokenizers] convert_to_tensors: don't reconvert when the type is already right (#8283)
* don't reconvert when the type is already right
* better name
* adjust logic as suggested
* merge
This commit is contained in:
@@ -53,6 +53,15 @@ if is_torch_available():
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def _is_numpy(x):
|
||||
return isinstance(x, np.ndarray)
|
||||
|
||||
|
||||
def _is_jax(x):
|
||||
return isinstance(x, jnp.ndarray)
|
||||
|
||||
|
||||
if is_tokenizers_available():
|
||||
from tokenizers import AddedToken
|
||||
from tokenizers import Encoding as EncodingFast
|
||||
@@ -705,16 +714,20 @@ class BatchEncoding(UserDict):
|
||||
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
|
||||
)
|
||||
as_tensor = tf.constant
|
||||
is_tensor = tf.is_tensor
|
||||
elif tensor_type == TensorType.PYTORCH:
|
||||
if not is_torch_available():
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
as_tensor = torch.tensor
|
||||
is_tensor = torch.is_tensor
|
||||
elif tensor_type == TensorType.JAX:
|
||||
if not is_flax_available():
|
||||
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
||||
as_tensor = jnp.array
|
||||
is_tensor = _is_jax
|
||||
else:
|
||||
as_tensor = np.asarray
|
||||
is_tensor = _is_numpy
|
||||
# (mfuntowicz: This code is unreachable)
|
||||
# else:
|
||||
# raise ImportError(
|
||||
@@ -727,16 +740,17 @@ class BatchEncoding(UserDict):
|
||||
if prepend_batch_axis:
|
||||
value = [value]
|
||||
|
||||
tensor = as_tensor(value)
|
||||
if not is_tensor(value):
|
||||
tensor = as_tensor(value)
|
||||
|
||||
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
|
||||
# # at-least2d
|
||||
# if tensor.ndim > 2:
|
||||
# tensor = tensor.squeeze(0)
|
||||
# elif tensor.ndim < 2:
|
||||
# tensor = tensor[None, :]
|
||||
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
|
||||
# # at-least2d
|
||||
# if tensor.ndim > 2:
|
||||
# tensor = tensor.squeeze(0)
|
||||
# elif tensor.ndim < 2:
|
||||
# tensor = tensor[None, :]
|
||||
|
||||
self[key] = tensor
|
||||
self[key] = tensor
|
||||
except: # noqa E722
|
||||
if key == "overflowing_tokens":
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user