adding test for common properties and cleaning up a bit base class
This commit is contained in:
@@ -55,6 +55,22 @@ class CommonTestCases:
|
|||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_tokenizers_common_properties(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token",
|
||||||
|
"pad_token", "cls_token", "mask_token"]
|
||||||
|
for attr in attributes_list:
|
||||||
|
self.assertTrue(hasattr(tokenizer, attr))
|
||||||
|
self.assertTrue(hasattr(tokenizer, attr + "_id"))
|
||||||
|
|
||||||
|
self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
|
||||||
|
self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids'))
|
||||||
|
|
||||||
|
attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder",
|
||||||
|
"added_tokens_decoder"]
|
||||||
|
for attr in attributes_list:
|
||||||
|
self.assertTrue(hasattr(tokenizer, attr))
|
||||||
|
|
||||||
def test_save_and_load_tokenizer(self):
|
def test_save_and_load_tokenizer(self):
|
||||||
# safety check on max_len default value so we are sure the test works
|
# safety check on max_len default value so we are sure the test works
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|||||||
@@ -162,58 +162,42 @@ class PreTrainedTokenizer(object):
|
|||||||
@property
|
@property
|
||||||
def bos_token_id(self):
|
def bos_token_id(self):
|
||||||
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
|
||||||
if self._bos_token is None:
|
return self.convert_tokens_to_ids(self.bos_token)
|
||||||
logger.error("Using bos_token, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._bos_token)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eos_token_id(self):
|
def eos_token_id(self):
|
||||||
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
|
||||||
if self._eos_token is None:
|
return self.convert_tokens_to_ids(self.eos_token)
|
||||||
logger.error("Using eos_token, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._eos_token)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unk_token_id(self):
|
def unk_token_id(self):
|
||||||
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
|
||||||
if self._unk_token is None:
|
return self.convert_tokens_to_ids(self.unk_token)
|
||||||
logger.error("Using unk_token, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._unk_token)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sep_token_id(self):
|
def sep_token_id(self):
|
||||||
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
|
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
|
||||||
if self._sep_token is None:
|
return self.convert_tokens_to_ids(self.sep_token)
|
||||||
logger.error("Using sep_token, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._sep_token)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pad_token_id(self):
|
def pad_token_id(self):
|
||||||
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
|
||||||
if self._pad_token is None:
|
return self.convert_tokens_to_ids(self.pad_token)
|
||||||
logger.error("Using pad_token, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._pad_token)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cls_token_id(self):
|
def cls_token_id(self):
|
||||||
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
|
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
|
||||||
if self._cls_token is None:
|
return self.convert_tokens_to_ids(self.cls_token)
|
||||||
logger.error("Using cls_token, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._cls_token)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mask_token_id(self):
|
def mask_token_id(self):
|
||||||
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
|
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
|
||||||
if self._mask_token is None:
|
return self.convert_tokens_to_ids(self.mask_token)
|
||||||
logger.error("Using mask_token, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._mask_token)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def additional_special_tokens_ids(self):
|
def additional_special_tokens_ids(self):
|
||||||
""" Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
|
""" Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
|
||||||
if self._additional_special_tokens is None:
|
return self.convert_tokens_to_ids(self.additional_special_tokens)
|
||||||
logger.error("Using additional_special_tokens, but it is not set yet.")
|
|
||||||
return self.convert_tokens_to_ids(self._additional_special_tokens)
|
|
||||||
|
|
||||||
def __init__(self, max_len=None, **kwargs):
|
def __init__(self, max_len=None, **kwargs):
|
||||||
self._bos_token = None
|
self._bos_token = None
|
||||||
@@ -653,6 +637,9 @@ class PreTrainedTokenizer(object):
|
|||||||
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
|
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
|
||||||
(resp. a sequence of ids), using the vocabulary.
|
(resp. a sequence of ids), using the vocabulary.
|
||||||
"""
|
"""
|
||||||
|
if tokens is None:
|
||||||
|
return None
|
||||||
|
|
||||||
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
||||||
return self._convert_token_to_id_with_added_voc(tokens)
|
return self._convert_token_to_id_with_added_voc(tokens)
|
||||||
|
|
||||||
@@ -666,6 +653,9 @@ class PreTrainedTokenizer(object):
|
|||||||
return ids
|
return ids
|
||||||
|
|
||||||
def _convert_token_to_id_with_added_voc(self, token):
|
def _convert_token_to_id_with_added_voc(self, token):
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
|
|
||||||
if token in self.added_tokens_encoder:
|
if token in self.added_tokens_encoder:
|
||||||
return self.added_tokens_encoder[token]
|
return self.added_tokens_encoder[token]
|
||||||
return self._convert_token_to_id(token)
|
return self._convert_token_to_id(token)
|
||||||
|
|||||||
Reference in New Issue
Block a user