fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is … (#24772)

fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor."

Co-authored-by: 刘长伟 <hzliuchw@corp.netease.com>
This commit is contained in:
Leo
2023-07-26 21:07:21 +08:00
committed by GitHub
parent 04a5c859b0
commit c53c8e490c

View File

@@ -700,8 +700,13 @@ class BatchEncoding(UserDict):
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch import torch
as_tensor = torch.tensor
is_tensor = torch.is_tensor is_tensor = torch.is_tensor
def as_tensor(value, dtype=None):
if isinstance(value, list) and isinstance(value[0], np.ndarray):
return torch.tensor(np.array(value))
return torch.tensor(value)
elif tensor_type == TensorType.JAX: elif tensor_type == TensorType.JAX:
if not is_flax_available(): if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")