From 9e033649991d3785e1d1e9e1da87ba37bb223503 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Tue, 16 Jun 2020 09:25:25 +0200 Subject: [PATCH] Ability to pickle/unpickle BatchEncoding pickle (reimport) (#5039) * Added is_fast property on BatchEncoding to indicate if the object comes from a Fast Tokenizer. * Added __getstate__() & __setstate__() so that BatchEncoding is picklable. * Correct tokens() return type from List[int] to List[str] * Added unittest for BatchEncoding pickle/unpickle * Added unittest for BatchEncoding is_fast * More careful checking on BatchEncoding unpickle tests. * Formatting. * is_fast should assertTrue on Rust tokenizers. * Ensure tensorflow has correct way of checking array_equal * More formatting. --- src/transformers/tokenization_utils_base.py | 20 +++- tests/test_tokenization_utils.py | 102 +++++++++++++++++++- 2 files changed, 116 insertions(+), 6 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 71d01be78b..64e3634372 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -155,6 +155,14 @@ class BatchEncoding(UserDict): self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis) + @property + def is_fast(self): + """ + Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast + Returns: True if generated from subclasses of PreTrainedTokenizerFast, False otherwise + """ + return self._encodings is not None + def __getitem__(self, item: Union[int, str]) -> EncodingFast: """ If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...) 
If the key is an integer, get the EncodingFast for batch item with index `key` @@ -175,6 +183,16 @@ class BatchEncoding(UserDict): except KeyError: raise AttributeError + def __getstate__(self): + return {"data": self.data, "encodings": self._encodings} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + if "encodings" in state: + self._encodings = state["encodings"] + def keys(self): return self.data.keys() @@ -197,7 +215,7 @@ class BatchEncoding(UserDict): """ return self._encodings - def tokens(self, batch_index: int = 0) -> List[int]: + def tokens(self, batch_index: int = 0) -> List[str]: if not self._encodings: raise ValueError("tokens() is not available when using Python based tokenizers") return self._encodings[batch_index].tokens diff --git a/tests/test_tokenization_utils.py b/tests/test_tokenization_utils.py index de0ac69ac0..fb3f677d9e 100644 --- a/tests/test_tokenization_utils.py +++ b/tests/test_tokenization_utils.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - +import pickle import unittest +from typing import Callable, Optional -from transformers import PreTrainedTokenizer, TensorType +from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType from transformers.tokenization_gpt2 import GPT2Tokenizer -from .utils import slow +from .utils import require_tf, require_torch, slow class TokenizerUtilsTest(unittest.TestCase): @@ -36,11 +36,103 @@ class TokenizerUtilsTest(unittest.TestCase): special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) self.assertIsInstance(special_tok_id, int) + def assert_dump_and_restore(self, be_original: BatchEncoding, equal_op: Optional[Callable] = None): + batch_encoding_str = pickle.dumps(be_original) + self.assertIsNotNone(batch_encoding_str) + + be_restored = pickle.loads(batch_encoding_str) + + # Ensure is_fast is correctly restored + self.assertEqual(be_restored.is_fast, be_original.is_fast) + + # Ensure encodings are potentially correctly restored + if be_original.is_fast: + self.assertIsNotNone(be_restored.encodings) + else: + self.assertIsNone(be_restored.encodings) + + # Ensure the values are the same + for original_v, restored_v in zip(be_original.values(), be_restored.values()): + if equal_op: + self.assertTrue(equal_op(restored_v, original_v)) + else: + self.assertEqual(restored_v, original_v) + @slow def test_pretrained_tokenizers(self): self.check_tokenizer_from_pretrained(GPT2Tokenizer) - def check_tensor_type_from_str(self): + def test_tensor_type_from_str(self): self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW) self.assertEqual(TensorType("pt"), TensorType.PYTORCH) self.assertEqual(TensorType("np"), TensorType.NUMPY) + + def test_batch_encoding_pickle(self): + import numpy as np + + tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased") + tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased") + + # Python no tensor + with self.subTest("BatchEncoding (Python, return_tensors=None)"): + 
self.assert_dump_and_restore(tokenizer_p("Small example to encode")) + + with self.subTest("BatchEncoding (Python, return_tensors=NUMPY)"): + self.assert_dump_and_restore( + tokenizer_p("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal + ) + + with self.subTest("BatchEncoding (Rust, return_tensors=None)"): + self.assert_dump_and_restore(tokenizer_r("Small example to encode")) + + with self.subTest("BatchEncoding (Rust, return_tensors=NUMPY)"): + self.assert_dump_and_restore( + tokenizer_r("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal + ) + + @require_tf + def test_batch_encoding_pickle_tf(self): + import tensorflow as tf + + def tf_array_equals(t1, t2): + return tf.reduce_all(tf.equal(t1, t2)) + + tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased") + tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased") + + with self.subTest("BatchEncoding (Python, return_tensors=TENSORFLOW)"): + self.assert_dump_and_restore( + tokenizer_p("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals + ) + + with self.subTest("BatchEncoding (Rust, return_tensors=TENSORFLOW)"): + self.assert_dump_and_restore( + tokenizer_r("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals + ) + + @require_torch + def test_batch_encoding_pickle_pt(self): + import torch + + tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased") + tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased") + + with self.subTest("BatchEncoding (Python, return_tensors=PYTORCH)"): + self.assert_dump_and_restore( + tokenizer_p("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal + ) + + with self.subTest("BatchEncoding (Rust, return_tensors=PYTORCH)"): + self.assert_dump_and_restore( + tokenizer_r("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal + ) + + def test_batch_encoding_is_fast(self): + tokenizer_p = 
BertTokenizer.from_pretrained("bert-base-cased") + tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased") + + with self.subTest("Python Tokenizer"): + self.assertFalse(tokenizer_p("Small example to_encode").is_fast) + + with self.subTest("Rust Tokenizer"): + self.assertTrue(tokenizer_r("Small example to_encode").is_fast)