[tokenizers] convert_to_tensors: don't reconvert when the type is already right (#8283)

* don't reconvert when the type is already right

* better name

* adjust logic as suggested

* merge
This commit is contained in:
Stas Bekman
2020-11-19 12:06:01 -08:00
committed by GitHub
parent 20b658607e
commit 42111f1d56
2 changed files with 51 additions and 9 deletions

View File

@@ -53,6 +53,15 @@ if is_torch_available():
if is_flax_available():
import jax.numpy as jnp
def _is_numpy(x):
return isinstance(x, np.ndarray)
def _is_jax(x):
return isinstance(x, jnp.ndarray)
if is_tokenizers_available():
from tokenizers import AddedToken
from tokenizers import Encoding as EncodingFast
@@ -705,16 +714,20 @@ class BatchEncoding(UserDict):
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
)
as_tensor = tf.constant
is_tensor = tf.is_tensor
elif tensor_type == TensorType.PYTORCH:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
as_tensor = torch.tensor
is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
as_tensor = jnp.array
is_tensor = _is_jax
else:
as_tensor = np.asarray
is_tensor = _is_numpy
# (mfuntowicz: This code is unreachable)
# else:
# raise ImportError(
@@ -727,16 +740,17 @@ class BatchEncoding(UserDict):
if prepend_batch_axis:
value = [value]
tensor = as_tensor(value)
if not is_tensor(value):
tensor = as_tensor(value)
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
# # at-least2d
# if tensor.ndim > 2:
# tensor = tensor.squeeze(0)
# elif tensor.ndim < 2:
# tensor = tensor[None, :]
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
# # at-least2d
# if tensor.ndim > 2:
# tensor = tensor.squeeze(0)
# elif tensor.ndim < 2:
# tensor = tensor[None, :]
self[key] = tensor
self[key] = tensor
except: # noqa E722
if key == "overflowing_tokens":
raise ValueError(