From c53c8e490c158f505a9271f7f5d8248473da3d24 Mon Sep 17 00:00:00 2001
From: Leo
Date: Wed, 26 Jul 2023 21:07:21 +0800
Subject: [PATCH] =?UTF-8?q?fix=20"UserWarning:=20Creating=20a=20tensor=20f?=
 =?UTF-8?q?rom=20a=20list=20of=20numpy.ndarrays=20is=20=E2=80=A6=20(#24772?=
 =?UTF-8?q?)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

fix "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor."

Co-authored-by: 刘长伟
---
 src/transformers/tokenization_utils_base.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index f5271aff66..c2d242c9a3 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -700,8 +700,17 @@ class BatchEncoding(UserDict):
                 raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
             import torch
 
-            as_tensor = torch.tensor
             is_tensor = torch.is_tensor
+
+            def as_tensor(value, dtype=None):
+                # Stack a list of ndarrays into a single ndarray first; torch.tensor on
+                # a list of ndarrays is extremely slow and emits a UserWarning.
+                # The length check keeps an empty list from raising IndexError.
+                # NOTE: `dtype` is accepted but not forwarded to torch.tensor.
+                if isinstance(value, list) and len(value) > 0 and isinstance(value[0], np.ndarray):
+                    return torch.tensor(np.array(value))
+                return torch.tensor(value)
+
         elif tensor_type == TensorType.JAX:
             if not is_flax_available():
                 raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")