From 47d6853439318f1be596219e270bee4e3819dfbb Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 23 Aug 2019 17:31:11 +0200 Subject: [PATCH] adding max_lengths for single sentences and sentences pairs --- pytorch_transformers/tokenization_bert.py | 8 ++++++++ pytorch_transformers/tokenization_roberta.py | 8 ++++++++ pytorch_transformers/tokenization_utils.py | 8 ++++++++ pytorch_transformers/tokenization_xlm.py | 8 ++++++++ pytorch_transformers/tokenization_xlnet.py | 8 ++++++++ 5 files changed, 40 insertions(+) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 04f35aa466..8ea71ba92b 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -139,6 +139,14 @@ class BertTokenizer(PreTrainedTokenizer): tokenize_chinese_chars=tokenize_chinese_chars) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) + @property + def max_len_single_sentence(self): + return self.max_len - 2 # take into account special tokens + + @property + def max_len_sentences_pair(self): + return self.max_len - 3 # take into account special tokens + @property def vocab_size(self): return len(self.vocab) diff --git a/pytorch_transformers/tokenization_roberta.py b/pytorch_transformers/tokenization_roberta.py index edf4717c89..44047e636f 100644 --- a/pytorch_transformers/tokenization_roberta.py +++ b/pytorch_transformers/tokenization_roberta.py @@ -160,6 +160,14 @@ class RobertaTokenizer(PreTrainedTokenizer): text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) return text + @property + def max_len_single_sentence(self): + return self.max_len - 2 # take into account special tokens + + @property + def max_len_sentences_pair(self): + return self.max_len - 4 # take into account special tokens + def add_special_tokens_single_sentence(self, token_ids): """ Adds special tokens to a sequence for sequence classification tasks. diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index d2855e0922..a128c3fd72 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -67,6 +67,14 @@ class PreTrainedTokenizer(object): "pad_token", "cls_token", "mask_token", "additional_special_tokens"] + @property + def max_len_single_sentence(self): + return self.max_len # Default to max_len but can be smaller in specific tokenizers to take into account special tokens + + @property + def max_len_sentences_pair(self): + return self.max_len # Default to max_len but can be smaller in specific tokenizers to take into account special tokens + @property def bos_token(self): """ Beginning of sentence token (string). Log an error if used while not having been set. """ diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 2d2f3a8cd4..b544923e35 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -215,6 +215,14 @@ class XLMTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace('', ' ').strip() return out_string + @property + def max_len_single_sentence(self): + return self.max_len - 2 # take into account special tokens + + @property + def max_len_sentences_pair(self): + return self.max_len - 3 # take into account special tokens + def add_special_tokens_single_sentence(self, token_ids): """ Adds special tokens to a sequence for sequence classification tasks. diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index 371b3c9407..a282d67904 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -177,6 +177,14 @@ class XLNetTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() return out_string + @property + def max_len_single_sentence(self): + return self.max_len - 2 # take into account special tokens + + @property + def max_len_sentences_pair(self): + return self.max_len - 3 # take into account special tokens + def add_special_tokens_single_sentence(self, token_ids): """ Adds special tokens to a sequence pair for sequence classification tasks.