From 42111f1d56947797d9dfb0908908f42a22ca9823 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 19 Nov 2020 12:06:01 -0800 Subject: [PATCH] [tokenizers] convert_to_tensors: don't reconvert when the type is already right (#8283) * don't reconvert when the type is already right * better name * adjust logic as suggested * merge --- src/transformers/tokenization_utils_base.py | 30 +++++++++++++++------ tests/test_tokenization_utils.py | 30 ++++++++++++++++++++- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 0fbad8b74a..1b3bdd9a5d 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -53,6 +53,15 @@ if is_torch_available(): if is_flax_available(): import jax.numpy as jnp + +def _is_numpy(x): + return isinstance(x, np.ndarray) + + +def _is_jax(x): + return isinstance(x, jnp.ndarray) + + if is_tokenizers_available(): from tokenizers import AddedToken from tokenizers import Encoding as EncodingFast @@ -705,16 +714,20 @@ class BatchEncoding(UserDict): "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." ) as_tensor = tf.constant + is_tensor = tf.is_tensor elif tensor_type == TensorType.PYTORCH: if not is_torch_available(): raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") as_tensor = torch.tensor + is_tensor = torch.is_tensor elif tensor_type == TensorType.JAX: if not is_flax_available(): raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") as_tensor = jnp.array + is_tensor = _is_jax else: as_tensor = np.asarray + is_tensor = _is_numpy # (mfuntowicz: This code is unreachable) # else: # raise ImportError( @@ -727,16 +740,17 @@ class BatchEncoding(UserDict): if prepend_batch_axis: value = [value] - tensor = as_tensor(value) + if not is_tensor(value): + tensor = as_tensor(value) - # Removing this for now in favor of controlling the shape with `prepend_batch_axis` - # # at-least2d - # if tensor.ndim > 2: - # tensor = tensor.squeeze(0) - # elif tensor.ndim < 2: - # tensor = tensor[None, :] + # Removing this for now in favor of controlling the shape with `prepend_batch_axis` + # # at-least2d + # if tensor.ndim > 2: + # tensor = tensor.squeeze(0) + # elif tensor.ndim < 2: + # tensor = tensor[None, :] - self[key] = tensor + self[key] = tensor except: # noqa E722 if key == "overflowing_tokens": raise ValueError( diff --git a/tests/test_tokenization_utils.py b/tests/test_tokenization_utils.py index 05c6d19c32..b10369bb63 100644 --- a/tests/test_tokenization_utils.py +++ b/tests/test_tokenization_utils.py @@ -20,7 +20,7 @@ import numpy as np from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer -from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow +from transformers.testing_utils import CaptureStderr, require_flax, require_tf, require_tokenizers, require_torch, slow class TokenizerUtilsTest(unittest.TestCase): @@ -156,6 +156,10 @@ class TokenizerUtilsTest(unittest.TestCase): tensor_batch = batch.convert_to_tensors(tensor_type="np") self.assertEqual(tensor_batch["inputs"].shape, (2, 3)) self.assertEqual(tensor_batch["labels"].shape, (2,)) + # test converting the converted + with CaptureStderr() as cs: + tensor_batch = batch.convert_to_tensors(tensor_type="np") + self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}") batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0}) tensor_batch = batch.convert_to_tensors(tensor_type="np", prepend_batch_axis=True) @@ -168,6 +172,10 @@ class TokenizerUtilsTest(unittest.TestCase): tensor_batch = batch.convert_to_tensors(tensor_type="pt") self.assertEqual(tensor_batch["inputs"].shape, (2, 3)) self.assertEqual(tensor_batch["labels"].shape, (2,)) + # test converting the converted + with CaptureStderr() as cs: + tensor_batch = batch.convert_to_tensors(tensor_type="pt") + self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}") batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0}) tensor_batch = batch.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True) @@ -180,12 +188,32 @@ class TokenizerUtilsTest(unittest.TestCase): tensor_batch = batch.convert_to_tensors(tensor_type="tf") self.assertEqual(tensor_batch["inputs"].shape, (2, 3)) self.assertEqual(tensor_batch["labels"].shape, (2,)) + # test converting the converted + with CaptureStderr() as cs: + tensor_batch = batch.convert_to_tensors(tensor_type="tf") + self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}") batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0}) tensor_batch = batch.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True) self.assertEqual(tensor_batch["inputs"].shape, (1, 3)) self.assertEqual(tensor_batch["labels"].shape, (1,)) + @require_flax + def test_batch_encoding_with_labels_jax(self): + batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]}) + tensor_batch = batch.convert_to_tensors(tensor_type="jax") + self.assertEqual(tensor_batch["inputs"].shape, (2, 3)) + self.assertEqual(tensor_batch["labels"].shape, (2,)) + # test converting the converted + with CaptureStderr() as cs: + tensor_batch = batch.convert_to_tensors(tensor_type="jax") + self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}") + + batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0}) + tensor_batch = batch.convert_to_tensors(tensor_type="jax", prepend_batch_axis=True) + self.assertEqual(tensor_batch["inputs"].shape, (1, 3)) + self.assertEqual(tensor_batch["labels"].shape, (1,)) + def test_padding_accepts_tensors(self): features = [{"input_ids": np.array([0, 1, 2])}, {"input_ids": np.array([0, 1, 2, 3])}] tokenizer = BertTokenizer.from_pretrained("bert-base-cased")