From 3bcbebd440c220adbaab657f2d13dac7c89f6453 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 23 Aug 2019 22:07:26 +0200 Subject: [PATCH] max_len_single_sentence & max_len_sentences_pair as attributes so they can be modified --- pytorch_transformers/tokenization_bert.py | 11 +++-------- pytorch_transformers/tokenization_gpt2.py | 2 ++ pytorch_transformers/tokenization_openai.py | 3 +++ pytorch_transformers/tokenization_roberta.py | 11 +++-------- pytorch_transformers/tokenization_transfo_xl.py | 4 ++++ pytorch_transformers/tokenization_utils.py | 11 +++-------- pytorch_transformers/tokenization_xlm.py | 12 ++++-------- pytorch_transformers/tokenization_xlnet.py | 12 ++++-------- 8 files changed, 26 insertions(+), 40 deletions(-) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 8ea71ba92b..92f027038d 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -125,6 +125,9 @@ class BertTokenizer(PreTrainedTokenizer): super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, **kwargs) + self.max_len_single_sentence = self.max_len - 2 # take into account special tokens + self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens + if not os.path.isfile(vocab_file): raise ValueError( "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " @@ -139,14 +142,6 @@ class BertTokenizer(PreTrainedTokenizer): tokenize_chinese_chars=tokenize_chinese_chars) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) - @property - def max_len_single_sentence(self): - return self.max_len - 2 # take into account special tokens - - @property - def max_len_sentences_pair(self): - return self.max_len - 3 # take into account special tokens - @property def vocab_size(self): return len(self.vocab) diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index e67f25ff59..13806a3708 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -108,6 +108,8 @@ class GPT2Tokenizer(PreTrainedTokenizer): def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>", bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) + self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens + self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens self.encoder = json.load(open(vocab_file)) self.decoder = {v:k for k,v in self.encoder.items()} diff --git a/pytorch_transformers/tokenization_openai.py b/pytorch_transformers/tokenization_openai.py index 51b418ebd3..0efbdb37c0 100644 --- a/pytorch_transformers/tokenization_openai.py +++ b/pytorch_transformers/tokenization_openai.py @@ -87,6 +87,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) + self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens + self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens + try: import ftfy from spacy.lang.en import English diff --git a/pytorch_transformers/tokenization_roberta.py b/pytorch_transformers/tokenization_roberta.py index 44047e636f..e8ab29238e 100644 --- a/pytorch_transformers/tokenization_roberta.py +++ b/pytorch_transformers/tokenization_roberta.py @@ -77,6 +77,9 @@ class RobertaTokenizer(PreTrainedTokenizer): sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, mask_token=mask_token, **kwargs) + self.max_len_single_sentence = self.max_len - 2 # take into account special tokens + self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens + self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding @@ -160,14 +163,6 @@ class RobertaTokenizer(PreTrainedTokenizer): text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) return text - @property - def max_len_single_sentence(self): - return self.max_len - 2 # take into account special tokens - - @property - def max_len_sentences_pair(self): - return self.max_len - 4 # take into account special tokens - def add_special_tokens_single_sentence(self, token_ids): """ Adds special tokens to a sequence for sequence classification tasks. diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py index 992dff80d5..c603ba695c 100644 --- a/pytorch_transformers/tokenization_transfo_xl.py +++ b/pytorch_transformers/tokenization_transfo_xl.py @@ -73,6 +73,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer): super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs) + + self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens + self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens + if never_split is None: never_split = self.all_special_tokens if special is None: diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index a128c3fd72..2fb7f87e9c 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -67,14 +67,6 @@ class PreTrainedTokenizer(object): "pad_token", "cls_token", "mask_token", "additional_special_tokens"] - @property - def max_len_single_sentence(self): - return self.max_len # Default to max_len but can be smaller in specific tokenizers to take into account special tokens - - @property - def max_len_sentences_pair(self): - return self.max_len # Default to max_len but can be smaller in specific tokenizers to take into account special tokens - @property def bos_token(self): """ Beginning of sentence token (string). Log an error if used while not having been set. """ @@ -174,6 +166,9 @@ class PreTrainedTokenizer(object): self._additional_special_tokens = [] self.max_len = max_len if max_len is not None else int(1e12) + self.max_len_single_sentence = self.max_len + self.max_len_sentences_pair = self.max_len + self.added_tokens_encoder = {} self.added_tokens_decoder = {} diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index b544923e35..2b930458bb 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -122,6 +122,10 @@ class XLMTokenizer(PreTrainedTokenizer): cls_token=cls_token, mask_token=mask_token, additional_special_tokens=additional_special_tokens, **kwargs) + + self.max_len_single_sentence = self.max_len - 2 # take into account special tokens + self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens + try: import ftfy from spacy.lang.en import English @@ -215,14 +219,6 @@ class XLMTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace('', ' ').strip() return out_string - @property - def max_len_single_sentence(self): - return self.max_len - 2 # take into account special tokens - - @property - def max_len_sentences_pair(self): - return self.max_len - 3 # take into account special tokens - def add_special_tokens_single_sentence(self, token_ids): """ Adds special tokens to a sequence for sequence classification tasks. diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index a282d67904..ac7231bb68 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -71,6 +71,10 @@ class XLNetTokenizer(PreTrainedTokenizer): pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, additional_special_tokens= additional_special_tokens, **kwargs) + + self.max_len_single_sentence = self.max_len - 2 # take into account special tokens + self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens + try: import sentencepiece as spm except ImportError: @@ -177,14 +181,6 @@ class XLNetTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() return out_string - @property - def max_len_single_sentence(self): - return self.max_len - 2 # take into account special tokens - - @property - def max_len_sentences_pair(self): - return self.max_len - 3 # take into account special tokens - def add_special_tokens_single_sentence(self, token_ids): """ Adds special tokens to a sequence pair for sequence classification tasks.