From d737947725afd06fc3fd0e57939266eba2709bfd Mon Sep 17 00:00:00 2001 From: maru0kun <53220859+maru0kun@users.noreply.github.com> Date: Thu, 5 Sep 2019 19:24:57 +0900 Subject: [PATCH 1/2] Fix typo --- pytorch_transformers/tokenization_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 53b8d245b8..9bb69eb703 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -174,7 +174,7 @@ class PreTrainedTokenizer(object): return self.convert_tokens_to_ids(self._eos_token) @property - def unk_token_is(self): + def unk_token_id(self): """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """ if self._unk_token is None: logger.error("Using unk_token, but it is not set yet.") From 5c6cac102b3347960684356d253bb97b4ef2da75 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 5 Sep 2019 21:31:29 +0200 Subject: [PATCH 2/2] adding test for common properties and cleaning up a bit base class --- .../tests/tokenization_tests_commons.py | 16 ++++++++ pytorch_transformers/tokenization_utils.py | 38 +++++++------------ 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 65f45c496c..3da0494ac4 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -55,6 +55,22 @@ class CommonTestCases: def get_input_output_texts(self): raise NotImplementedError + def test_tokenizers_common_properties(self): + tokenizer = self.get_tokenizer() + attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token", + "pad_token", "cls_token", "mask_token"] + for attr in attributes_list: + self.assertTrue(hasattr(tokenizer, attr)) + self.assertTrue(hasattr(tokenizer, attr + "_id")) + + self.assertTrue(hasattr(tokenizer, "additional_special_tokens")) + self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids')) + + attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder", + "added_tokens_decoder"] + for attr in attributes_list: + self.assertTrue(hasattr(tokenizer, attr)) + def test_save_and_load_tokenizer(self): # safety check on max_len default value so we are sure the test works tokenizer = self.get_tokenizer() diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 9bb69eb703..1e2cd59648 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -162,58 +162,42 @@ class PreTrainedTokenizer(object): @property def bos_token_id(self): """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """ - if self._bos_token is None: - logger.error("Using bos_token, but it is not set yet.") - return self.convert_tokens_to_ids(self._bos_token) + return self.convert_tokens_to_ids(self.bos_token) @property def eos_token_id(self): """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """ - if self._eos_token is None: - logger.error("Using eos_token, but it is not set yet.") - return self.convert_tokens_to_ids(self._eos_token) + return self.convert_tokens_to_ids(self.eos_token) @property def unk_token_id(self): """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """ - if self._unk_token is None: - logger.error("Using unk_token, but it is not set yet.") - return self.convert_tokens_to_ids(self._unk_token) + return self.convert_tokens_to_ids(self.unk_token) @property def sep_token_id(self): """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """ - if self._sep_token is None: - logger.error("Using sep_token, but it is not set yet.") - return self.convert_tokens_to_ids(self._sep_token) + return self.convert_tokens_to_ids(self.sep_token) @property def pad_token_id(self): """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """ - if self._pad_token is None: - logger.error("Using pad_token, but it is not set yet.") - return self.convert_tokens_to_ids(self._pad_token) + return self.convert_tokens_to_ids(self.pad_token) @property def cls_token_id(self): """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ - if self._cls_token is None: - logger.error("Using cls_token, but it is not set yet.") - return self.convert_tokens_to_ids(self._cls_token) + return self.convert_tokens_to_ids(self.cls_token) @property def mask_token_id(self): """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ - if self._mask_token is None: - logger.error("Using mask_token, but it is not set yet.") - return self.convert_tokens_to_ids(self._mask_token) + return self.convert_tokens_to_ids(self.mask_token) @property def additional_special_tokens_ids(self): """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """ - if self._additional_special_tokens is None: - logger.error("Using additional_special_tokens, but it is not set yet.") - return self.convert_tokens_to_ids(self._additional_special_tokens) + return self.convert_tokens_to_ids(self.additional_special_tokens) def __init__(self, max_len=None, **kwargs): self._bos_token = None @@ -653,6 +637,9 @@ class PreTrainedTokenizer(object): """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id (resp. a sequence of ids), using the vocabulary. """ + if tokens is None: + return None + if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): return self._convert_token_to_id_with_added_voc(tokens) @@ -666,6 +653,9 @@ class PreTrainedTokenizer(object): return ids def _convert_token_to_id_with_added_voc(self, token): + if token is None: + return None + if token in self.added_tokens_encoder: return self.added_tokens_encoder[token] return self._convert_token_to_id(token)