Ability to pickle/unpickle BatchEncoding pickle (reimport) (#5039)
* Added is_fast property on BatchEncoding to indicate if the object comes from a Fast Tokenizer. * Added __get_state__() & __set_state__() to be pickable. * Correct tokens() return type from List[int] to List[str] * Added unittest for BatchEncoding pickle/unpickle * Added unittest for BatchEncoding is_fast * More careful checking on BatchEncoding unpickle tests. * Formatting. * is_fast should assertTrue on Rust tokenizers. * Ensure tensorflow has correct way of checking array_equal * More formatting.
This commit is contained in:
@@ -155,6 +155,14 @@ class BatchEncoding(UserDict):
|
||||
|
||||
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
|
||||
|
||||
@property
|
||||
def is_fast(self):
|
||||
"""
|
||||
Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast
|
||||
Returns: True if generated from subclasses of PreTrainedTokenizerFast, else otherwise
|
||||
"""
|
||||
return self._encodings is not None
|
||||
|
||||
def __getitem__(self, item: Union[int, str]) -> EncodingFast:
|
||||
""" If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
|
||||
If the key is an integer, get the EncodingFast for batch item with index `key`
|
||||
@@ -175,6 +183,16 @@ class BatchEncoding(UserDict):
|
||||
except KeyError:
|
||||
raise AttributeError
|
||||
|
||||
def __getstate__(self):
|
||||
return {"data": self.data, "encodings": self._encodings}
|
||||
|
||||
def __setstate__(self, state):
|
||||
if "data" in state:
|
||||
self.data = state["data"]
|
||||
|
||||
if "encodings" in state:
|
||||
self._encodings = state["encodings"]
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
@@ -197,7 +215,7 @@ class BatchEncoding(UserDict):
|
||||
"""
|
||||
return self._encodings
|
||||
|
||||
def tokens(self, batch_index: int = 0) -> List[int]:
|
||||
def tokens(self, batch_index: int = 0) -> List[str]:
|
||||
if not self._encodings:
|
||||
raise ValueError("tokens() is not available when using Python based tokenizers")
|
||||
return self._encodings[batch_index].tokens
|
||||
|
||||
Reference in New Issue
Block a user