Ability to pickle/unpickle BatchEncoding pickle (reimport) (#5039)
* Added is_fast property on BatchEncoding to indicate if the object comes from a Fast Tokenizer. * Added __get_state__() & __set_state__() to be pickable. * Correct tokens() return type from List[int] to List[str] * Added unittest for BatchEncoding pickle/unpickle * Added unittest for BatchEncoding is_fast * More careful checking on BatchEncoding unpickle tests. * Formatting. * is_fast should assertTrue on Rust tokenizers. * Ensure tensorflow has correct way of checking array_equal * More formatting.
This commit is contained in:
@@ -155,6 +155,14 @@ class BatchEncoding(UserDict):
|
|||||||
|
|
||||||
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
|
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_fast(self):
|
||||||
|
"""
|
||||||
|
Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast
|
||||||
|
Returns: True if generated from subclasses of PreTrainedTokenizerFast, else otherwise
|
||||||
|
"""
|
||||||
|
return self._encodings is not None
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, str]) -> EncodingFast:
|
def __getitem__(self, item: Union[int, str]) -> EncodingFast:
|
||||||
""" If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
|
""" If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
|
||||||
If the key is an integer, get the EncodingFast for batch item with index `key`
|
If the key is an integer, get the EncodingFast for batch item with index `key`
|
||||||
@@ -175,6 +183,16 @@ class BatchEncoding(UserDict):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise AttributeError
|
raise AttributeError
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return {"data": self.data, "encodings": self._encodings}
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
if "data" in state:
|
||||||
|
self.data = state["data"]
|
||||||
|
|
||||||
|
if "encodings" in state:
|
||||||
|
self._encodings = state["encodings"]
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self.data.keys()
|
return self.data.keys()
|
||||||
|
|
||||||
@@ -197,7 +215,7 @@ class BatchEncoding(UserDict):
|
|||||||
"""
|
"""
|
||||||
return self._encodings
|
return self._encodings
|
||||||
|
|
||||||
def tokens(self, batch_index: int = 0) -> List[int]:
|
def tokens(self, batch_index: int = 0) -> List[str]:
|
||||||
if not self._encodings:
|
if not self._encodings:
|
||||||
raise ValueError("tokens() is not available when using Python based tokenizers")
|
raise ValueError("tokens() is not available when using Python based tokenizers")
|
||||||
return self._encodings[batch_index].tokens
|
return self._encodings[batch_index].tokens
|
||||||
|
|||||||
@@ -12,14 +12,14 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import pickle
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, TensorType
|
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
|
||||||
from transformers.tokenization_gpt2 import GPT2Tokenizer
|
from transformers.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
from .utils import slow
|
from .utils import require_tf, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
class TokenizerUtilsTest(unittest.TestCase):
|
class TokenizerUtilsTest(unittest.TestCase):
|
||||||
@@ -36,11 +36,103 @@ class TokenizerUtilsTest(unittest.TestCase):
|
|||||||
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
|
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
|
||||||
self.assertIsInstance(special_tok_id, int)
|
self.assertIsInstance(special_tok_id, int)
|
||||||
|
|
||||||
|
def assert_dump_and_restore(self, be_original: BatchEncoding, equal_op: Optional[Callable] = None):
|
||||||
|
batch_encoding_str = pickle.dumps(be_original)
|
||||||
|
self.assertIsNotNone(batch_encoding_str)
|
||||||
|
|
||||||
|
be_restored = pickle.loads(batch_encoding_str)
|
||||||
|
|
||||||
|
# Ensure is_fast is correctly restored
|
||||||
|
self.assertEqual(be_restored.is_fast, be_original.is_fast)
|
||||||
|
|
||||||
|
# Ensure encodings are potentially correctly restored
|
||||||
|
if be_original.is_fast:
|
||||||
|
self.assertIsNotNone(be_restored.encodings)
|
||||||
|
else:
|
||||||
|
self.assertIsNone(be_restored.encodings)
|
||||||
|
|
||||||
|
# Ensure the keys are the same
|
||||||
|
for original_v, restored_v in zip(be_original.values(), be_restored.values()):
|
||||||
|
if equal_op:
|
||||||
|
self.assertTrue(equal_op(restored_v, original_v))
|
||||||
|
else:
|
||||||
|
self.assertEqual(restored_v, original_v)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_pretrained_tokenizers(self):
|
def test_pretrained_tokenizers(self):
|
||||||
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
|
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
|
||||||
|
|
||||||
def check_tensor_type_from_str(self):
|
def test_tensor_type_from_str(self):
|
||||||
self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW)
|
self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW)
|
||||||
self.assertEqual(TensorType("pt"), TensorType.PYTORCH)
|
self.assertEqual(TensorType("pt"), TensorType.PYTORCH)
|
||||||
self.assertEqual(TensorType("np"), TensorType.NUMPY)
|
self.assertEqual(TensorType("np"), TensorType.NUMPY)
|
||||||
|
|
||||||
|
def test_batch_encoding_pickle(self):
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||||
|
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
|
# Python no tensor
|
||||||
|
with self.subTest("BatchEncoding (Python, return_tensors=None)"):
|
||||||
|
self.assert_dump_and_restore(tokenizer_p("Small example to encode"))
|
||||||
|
|
||||||
|
with self.subTest("BatchEncoding (Python, return_tensors=NUMPY)"):
|
||||||
|
self.assert_dump_and_restore(
|
||||||
|
tokenizer_p("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest("BatchEncoding (Rust, return_tensors=None)"):
|
||||||
|
self.assert_dump_and_restore(tokenizer_r("Small example to encode"))
|
||||||
|
|
||||||
|
with self.subTest("BatchEncoding (Rust, return_tensors=NUMPY)"):
|
||||||
|
self.assert_dump_and_restore(
|
||||||
|
tokenizer_r("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_batch_encoding_pickle_tf(self):
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
def tf_array_equals(t1, t2):
|
||||||
|
return tf.reduce_all(tf.equal(t1, t2))
|
||||||
|
|
||||||
|
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||||
|
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
|
with self.subTest("BatchEncoding (Python, return_tensors=TENSORFLOW)"):
|
||||||
|
self.assert_dump_and_restore(
|
||||||
|
tokenizer_p("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest("BatchEncoding (Rust, return_tensors=TENSORFLOW)"):
|
||||||
|
self.assert_dump_and_restore(
|
||||||
|
tokenizer_r("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_batch_encoding_pickle_pt(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||||
|
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
|
with self.subTest("BatchEncoding (Python, return_tensors=PYTORCH)"):
|
||||||
|
self.assert_dump_and_restore(
|
||||||
|
tokenizer_p("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest("BatchEncoding (Rust, return_tensors=PYTORCH)"):
|
||||||
|
self.assert_dump_and_restore(
|
||||||
|
tokenizer_r("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_batch_encoding_is_fast(self):
|
||||||
|
tokenizer_p = BertTokenizer.from_pretrained("bert-base-cased")
|
||||||
|
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
|
with self.subTest("Python Tokenizer"):
|
||||||
|
self.assertFalse(tokenizer_p("Small example to_encode").is_fast)
|
||||||
|
|
||||||
|
with self.subTest("Rust Tokenizer"):
|
||||||
|
self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
|
||||||
|
|||||||
Reference in New Issue
Block a user