Ability to pickle/unpickle BatchEncoding pickle (reimport) (#5039)
* Added is_fast property on BatchEncoding to indicate if the object comes from a Fast Tokenizer. * Added __get_state__() & __set_state__() to be pickable. * Correct tokens() return type from List[int] to List[str] * Added unittest for BatchEncoding pickle/unpickle * Added unittest for BatchEncoding is_fast * More careful checking on BatchEncoding unpickle tests. * Formatting. * is_fast should assertTrue on Rust tokenizers. * Ensure tensorflow has correct way of checking array_equal * More formatting.
This commit is contained in:
@@ -155,6 +155,14 @@ class BatchEncoding(UserDict):
|
||||
|
||||
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
|
||||
|
||||
@property
|
||||
def is_fast(self):
|
||||
"""
|
||||
Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast
|
||||
Returns: True if generated from subclasses of PreTrainedTokenizerFast, else otherwise
|
||||
"""
|
||||
return self._encodings is not None
|
||||
|
||||
def __getitem__(self, item: Union[int, str]) -> EncodingFast:
|
||||
""" If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
|
||||
If the key is an integer, get the EncodingFast for batch item with index `key`
|
||||
@@ -175,6 +183,16 @@ class BatchEncoding(UserDict):
|
||||
except KeyError:
|
||||
raise AttributeError
|
||||
|
||||
def __getstate__(self):
|
||||
return {"data": self.data, "encodings": self._encodings}
|
||||
|
||||
def __setstate__(self, state):
|
||||
if "data" in state:
|
||||
self.data = state["data"]
|
||||
|
||||
if "encodings" in state:
|
||||
self._encodings = state["encodings"]
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
@@ -197,7 +215,7 @@ class BatchEncoding(UserDict):
|
||||
"""
|
||||
return self._encodings
|
||||
|
||||
def tokens(self, batch_index: int = 0) -> List[int]:
|
||||
def tokens(self, batch_index: int = 0) -> List[str]:
|
||||
if not self._encodings:
|
||||
raise ValueError("tokens() is not available when using Python based tokenizers")
|
||||
return self._encodings[batch_index].tokens
|
||||
|
||||
@@ -12,14 +12,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import pickle
|
||||
import unittest
|
||||
from typing import Callable, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer, TensorType
|
||||
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
|
||||
from transformers.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from .utils import slow
|
||||
from .utils import require_tf, require_torch, slow
|
||||
|
||||
|
||||
class TokenizerUtilsTest(unittest.TestCase):
|
||||
@@ -36,11 +36,103 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
|
||||
self.assertIsInstance(special_tok_id, int)
|
||||
|
||||
def assert_dump_and_restore(self, be_original: BatchEncoding, equal_op: Optional[Callable] = None):
|
||||
batch_encoding_str = pickle.dumps(be_original)
|
||||
self.assertIsNotNone(batch_encoding_str)
|
||||
|
||||
be_restored = pickle.loads(batch_encoding_str)
|
||||
|
||||
# Ensure is_fast is correctly restored
|
||||
self.assertEqual(be_restored.is_fast, be_original.is_fast)
|
||||
|
||||
# Ensure encodings are potentially correctly restored
|
||||
if be_original.is_fast:
|
||||
self.assertIsNotNone(be_restored.encodings)
|
||||
else:
|
||||
self.assertIsNone(be_restored.encodings)
|
||||
|
||||
# Ensure the keys are the same
|
||||
for original_v, restored_v in zip(be_original.values(), be_restored.values()):
|
||||
if equal_op:
|
||||
self.assertTrue(equal_op(restored_v, original_v))
|
||||
else:
|
||||
self.assertEqual(restored_v, original_v)
|
||||
|
||||
@slow
|
||||
def test_pretrained_tokenizers(self):
|
||||
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
|
||||
|
||||
def check_tensor_type_from_str(self):
|
||||
def test_tensor_type_from_str(self):
|
||||
self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW)
|
||||
self.assertEqual(TensorType("pt"), TensorType.PYTORCH)
|
||||
self.assertEqual(TensorType("np"), TensorType.NUMPY)
|
||||
|
||||
def test_batch_encoding_pickle(self):
|
||||
import numpy as np
|
||||
|
||||
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||
|
||||
# Python no tensor
|
||||
with self.subTest("BatchEncoding (Python, return_tensors=None)"):
|
||||
self.assert_dump_and_restore(tokenizer_p("Small example to encode"))
|
||||
|
||||
with self.subTest("BatchEncoding (Python, return_tensors=NUMPY)"):
|
||||
self.assert_dump_and_restore(
|
||||
tokenizer_p("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
|
||||
)
|
||||
|
||||
with self.subTest("BatchEncoding (Rust, return_tensors=None)"):
|
||||
self.assert_dump_and_restore(tokenizer_r("Small example to encode"))
|
||||
|
||||
with self.subTest("BatchEncoding (Rust, return_tensors=NUMPY)"):
|
||||
self.assert_dump_and_restore(
|
||||
tokenizer_r("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_batch_encoding_pickle_tf(self):
|
||||
import tensorflow as tf
|
||||
|
||||
def tf_array_equals(t1, t2):
|
||||
return tf.reduce_all(tf.equal(t1, t2))
|
||||
|
||||
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||
|
||||
with self.subTest("BatchEncoding (Python, return_tensors=TENSORFLOW)"):
|
||||
self.assert_dump_and_restore(
|
||||
tokenizer_p("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
|
||||
)
|
||||
|
||||
with self.subTest("BatchEncoding (Rust, return_tensors=TENSORFLOW)"):
|
||||
self.assert_dump_and_restore(
|
||||
tokenizer_r("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_batch_encoding_pickle_pt(self):
|
||||
import torch
|
||||
|
||||
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||
|
||||
with self.subTest("BatchEncoding (Python, return_tensors=PYTORCH)"):
|
||||
self.assert_dump_and_restore(
|
||||
tokenizer_p("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
|
||||
)
|
||||
|
||||
with self.subTest("BatchEncoding (Rust, return_tensors=PYTORCH)"):
|
||||
self.assert_dump_and_restore(
|
||||
tokenizer_r("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
|
||||
)
|
||||
|
||||
def test_batch_encoding_is_fast(self):
|
||||
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||
|
||||
with self.subTest("Python Tokenizer"):
|
||||
self.assertFalse(tokenizer_p("Small example to_encode").is_fast)
|
||||
|
||||
with self.subTest("Rust Tokenizer"):
|
||||
self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
|
||||
|
||||
Reference in New Issue
Block a user