Ability to pickle/unpickle BatchEncoding pickle (reimport) (#5039)

* Added is_fast property on BatchEncoding to indicate if the object comes from a Fast Tokenizer.

* Added __get_state__() & __set_state__() to be pickable.

* Correct tokens() return type from List[int] to List[str]

* Added unittest for BatchEncoding pickle/unpickle

* Added unittest for BatchEncoding is_fast

* More careful checking on BatchEncoding unpickle tests.

* Formatting.

* is_fast should assertTrue on Rust tokenizers.

* Ensure tensorflow has correct way of checking array_equal

* More formatting.
This commit is contained in:
Funtowicz Morgan
2020-06-16 09:25:25 +02:00
committed by GitHub
parent f9f8a5312e
commit 9e03364999
2 changed files with 116 additions and 6 deletions

View File

@@ -155,6 +155,14 @@ class BatchEncoding(UserDict):
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
@property
def is_fast(self):
"""
Indicate if this BatchEncoding was generated from the result of a PreTrainedTokenizerFast
Returns: True if generated from subclasses of PreTrainedTokenizerFast, else otherwise
"""
return self._encodings is not None
def __getitem__(self, item: Union[int, str]) -> EncodingFast:
""" If the key is a string, get the value of the dict associated to `key` ('input_ids', 'attention_mask'...)
If the key is an integer, get the EncodingFast for batch item with index `key`
@@ -175,6 +183,16 @@ class BatchEncoding(UserDict):
except KeyError:
raise AttributeError
def __getstate__(self):
return {"data": self.data, "encodings": self._encodings}
def __setstate__(self, state):
if "data" in state:
self.data = state["data"]
if "encodings" in state:
self._encodings = state["encodings"]
def keys(self):
return self.data.keys()
@@ -197,7 +215,7 @@ class BatchEncoding(UserDict):
"""
return self._encodings
def tokens(self, batch_index: int = 0) -> List[int]:
def tokens(self, batch_index: int = 0) -> List[str]:
if not self._encodings:
raise ValueError("tokens() is not available when using Python based tokenizers")
return self._encodings[batch_index].tokens