[tokenizers] convert_to_tensors: don't reconvert when the type is already right (#8283)

* don't reconvert when the type is already right

* better name

* adjust logic as suggested

* merge
This commit is contained in:
Stas Bekman
2020-11-19 12:06:01 -08:00
committed by GitHub
parent 20b658607e
commit 42111f1d56
2 changed files with 51 additions and 9 deletions

View File

@@ -20,7 +20,7 @@ import numpy as np
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
from transformers.testing_utils import CaptureStderr, require_flax, require_tf, require_tokenizers, require_torch, slow
class TokenizerUtilsTest(unittest.TestCase):
@@ -156,6 +156,10 @@ class TokenizerUtilsTest(unittest.TestCase):
tensor_batch = batch.convert_to_tensors(tensor_type="np")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="np")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="np", prepend_batch_axis=True)
@@ -168,6 +172,10 @@ class TokenizerUtilsTest(unittest.TestCase):
tensor_batch = batch.convert_to_tensors(tensor_type="pt")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="pt")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True)
@@ -180,12 +188,32 @@ class TokenizerUtilsTest(unittest.TestCase):
tensor_batch = batch.convert_to_tensors(tensor_type="tf")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="tf")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True)
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
self.assertEqual(tensor_batch["labels"].shape, (1,))
@require_flax
def test_batch_encoding_with_labels_jax(self):
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
tensor_batch = batch.convert_to_tensors(tensor_type="jax")
self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
self.assertEqual(tensor_batch["labels"].shape, (2,))
# test converting the converted
with CaptureStderr() as cs:
tensor_batch = batch.convert_to_tensors(tensor_type="jax")
self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")
batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
tensor_batch = batch.convert_to_tensors(tensor_type="jax", prepend_batch_axis=True)
self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
self.assertEqual(tensor_batch["labels"].shape, (1,))
def test_padding_accepts_tensors(self):
features = [{"input_ids": np.array([0, 1, 2])}, {"input_ids": np.array([0, 1, 2, 3])}]
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")