fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is … (#24772)
fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor." Co-authored-by: 刘长伟 <hzliuchw@corp.netease.com>
This commit is contained in:
@@ -700,8 +700,13 @@ class BatchEncoding(UserDict):
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
import torch
|
||||
|
||||
as_tensor = torch.tensor
|
||||
is_tensor = torch.is_tensor
|
||||
|
||||
def as_tensor(value, dtype=None):
|
||||
if isinstance(value, list) and isinstance(value[0], np.ndarray):
|
||||
return torch.tensor(np.array(value))
|
||||
return torch.tensor(value)
|
||||
|
||||
elif tensor_type == TensorType.JAX:
|
||||
if not is_flax_available():
|
||||
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
||||
|
||||
Reference in New Issue
Block a user