fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is … (#24772)

fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor."

Co-authored-by: 刘长伟 <hzliuchw@corp.netease.com>
This commit is contained in:
Leo
2023-07-26 21:07:21 +08:00
committed by GitHub
parent 04a5c859b0
commit c53c8e490c

View File

@@ -700,8 +700,13 @@ class BatchEncoding(UserDict):
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
as_tensor = torch.tensor
is_tensor = torch.is_tensor
def as_tensor(value, dtype=None):
if isinstance(value, list) and isinstance(value[0], np.ndarray):
return torch.tensor(np.array(value))
return torch.tensor(value)
elif tensor_type == TensorType.JAX:
if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")