fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is … (#24772)
fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor." Co-authored-by: 刘长伟 <hzliuchw@corp.netease.com>
This commit is contained in:
@@ -700,8 +700,13 @@ class BatchEncoding(UserDict):
|
|||||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
as_tensor = torch.tensor
|
|
||||||
is_tensor = torch.is_tensor
|
is_tensor = torch.is_tensor
|
||||||
|
|
||||||
|
def as_tensor(value, dtype=None):
|
||||||
|
if isinstance(value, list) and isinstance(value[0], np.ndarray):
|
||||||
|
return torch.tensor(np.array(value))
|
||||||
|
return torch.tensor(value)
|
||||||
|
|
||||||
elif tensor_type == TensorType.JAX:
|
elif tensor_type == TensorType.JAX:
|
||||||
if not is_flax_available():
|
if not is_flax_available():
|
||||||
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
||||||
|
|||||||
Reference in New Issue
Block a user