From 3471ff0d35b0481816dde44a5ecf51da076ee467 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 12:23:30 -0500 Subject: [PATCH 01/11] FastPreTrainedTokenizer --- src/transformers/tokenization_utils.py | 127 +++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 73075521fe..ac5e2d9d07 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1410,3 +1410,130 @@ class PreTrainedTokenizer(object): .replace(" 're", "'re") ) return out_string + +class FastPreTrainedTokenizer(PreTrainedTokenizer): + def __init__(self, **kwargs): + super(FastPreTrainedTokenizer, self).__init__(**kwargs) + + @property + def tokenizer(self): + if self._tokenizer is None: + raise NotImplementedError + return self._tokenizer + + @property + def decoder(self): + if self._decoder is None: + raise NotImplementedError + return self._decoder + + @property + def vocab_size(self): + return self.tokenizer.get_vocab_size(False) + + def __len__(self): + return self.tokenizer.get_vocab_size(True) + + def _update_special_tokens(self): + self.tokenizer.add_special_tokens(self.all_special_tokens) + + @staticmethod + def _convert_encoding(encoding, + return_tensors=None, + return_token_type_ids=True, + return_attention_mask=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False): + encoding_dict = { + "input_ids": encoding.ids, + } + if return_token_type_ids: + encoding_dict["token_type_ids"] = encoding.type_ids + if return_attention_mask: + encoding_dict["attention_mask"] = encoding.attention_mask + if return_overflowing_tokens: + overflowing = encoding.overflowing + encoding_dict["overflowing_tokens"] = overflowing.ids if overflowing is not None else [] + if return_special_tokens_mask: + encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask + + # Prepare inputs as tensors if asked + if return_tensors == 'tf' and is_tf_available(): + encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]]) + encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]]) + + if "attention_mask" in encoding_dict: + encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]]) + + elif return_tensors == 'pt' and is_torch_available(): + encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]]) + encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]]) + + if "attention_mask" in encoding_dict: + encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]]) + elif return_tensors is not None: + logger.warning( + "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( + return_tensors)) + + return encoding_dict + + def encode_plus(self, + text, + text_pair=None, + return_tensors=None, + return_token_type_ids=True, + return_attention_mask=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + **kwargs): + encoding = self.tokenizer.encode(text, text_pair) + return self._convert_encoding(encoding, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask) + + def tokenize(self, text): + return self.tokenizer.encode(text).tokens + + def _convert_token_to_id_with_added_voc(self, token): + return self.tokenizer.token_to_id(token) + + def _convert_id_to_token(self, index): + return self.tokenizer.id_to_token(int(index)) + + def convert_tokens_to_string(self, tokens): + return self.decoder.decode(tokens) + + def add_tokens(self, new_tokens): + self.tokenizer.add_tokens(new_tokens) + + def encode_batch(self, texts, + return_tensors=None, + return_token_type_ids=True, + return_attention_mask=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False): + return [self._convert_encoding(encoding, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask) + for encoding in self.tokenizer.encode_batch(texts)] + + def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + text = self.tokenizer.decode(token_ids, skip_special_tokens) + + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def decode_batch(self, ids_batch, skip_special_tokens=False, clear_up_tokenization_spaces=True): + return [self.clean_up_tokenization(text) + if clear_up_tokenization_spaces else text + for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)] \ No newline at end of file From 041eac2d6d5f7bc04ec35772ff2ae738b96637ed Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 12:24:14 -0500 Subject: [PATCH 02/11] GPT2TokenizerFast --- src/transformers/__init__.py | 2 +- src/transformers/tokenization_gpt2.py | 35 ++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 84a308d1c1..e7cf22321b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -108,7 +108,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenize from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer from .tokenization_distilbert import DistilBertTokenizer -from .tokenization_gpt2 import GPT2Tokenizer +from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_roberta import RobertaTokenizer from .tokenization_t5 import T5Tokenizer diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 7687fd09fc..5b5be1253e 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -22,7 +22,7 @@ from functools import lru_cache import regex as re -from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer logger = logging.getLogger(__name__) @@ -246,3 +246,36 @@ class GPT2Tokenizer(PreTrainedTokenizer): index += 1 return vocab_file, merge_file + +class GPT2TokenizerFast(FastPreTrainedTokenizer): + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__(self, vocab_file, merges_file, unk_token="<|endoftext|>", bos_token="<|endoftext|>", + eos_token="<|endoftext|>", pad_to_max_length=False, add_prefix_space=False, + max_length=None, stride=0, truncation_strategy='longest_first', **kwargs): + + try: + from tokenizers import Tokenizer, models, pre_tokenizers, decoders + + super(GPT2TokenizerFast, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) + + self._tokenizer = Tokenizer(models.BPE.from_files(vocab_file, merges_file)) + self._update_special_tokens() + self._tokenizer.with_pre_tokenizer(pre_tokenizers.ByteLevel.new(add_prefix_space)) + self._tokenizer.with_decoder(decoders.ByteLevel.new()) + if max_length: + self._tokenizer.with_truncation(max_length, stride, truncation_strategy) + self._tokenizer.with_padding( + max_length if pad_to_max_length else None, + self.padding_side, + self.pad_token_id if self.pad_token_id is not None else 0, + self.pad_token_type_id, + self.pad_token if self.pad_token is not None else "" + ) + self._decoder = decoders.ByteLevel.new() + + except (AttributeError, ImportError) as e: + logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`") + raise e \ No newline at end of file From 951ae99bea3bd8a37397228b6d1f57257a71a6cf Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 12:24:24 -0500 Subject: [PATCH 03/11] BertTokenizerFast --- src/transformers/__init__.py | 2 +- src/transformers/tokenization_bert.py | 53 ++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e7cf22321b..e305c8b15b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -103,7 +103,7 @@ from .pipelines import ( ) from .tokenization_albert import AlbertTokenizer from .tokenization_auto import AutoTokenizer -from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer +from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index e2d3980c47..e2ba03594b 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -20,7 +20,7 @@ import logging import os import unicodedata -from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer logger = logging.getLogger(__name__) @@ -525,3 +525,54 @@ def _is_punctuation(char): if cat.startswith("P"): return True return False + +class BertTokenizerFast(FastPreTrainedTokenizer): + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, + unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", + mask_token="[MASK]", tokenize_chinese_chars=True, + max_length=None, pad_to_max_length=False, stride=0, + truncation_strategy='longest_first', add_special_tokens=True, **kwargs): + + try: + from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors + super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token, + pad_token=pad_token, cls_token=cls_token, + mask_token=mask_token, **kwargs) + + self._tokenizer = Tokenizer(models.WordPiece.from_files( + vocab_file, + unk_token=unk_token + )) + self._update_special_tokens() + self._tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new( + do_basic_tokenize=do_basic_tokenize, + do_lower_case=do_lower_case, + tokenize_chinese_chars=tokenize_chinese_chars, + never_split=never_split if never_split is not None else [], + )) + self._tokenizer.with_decoder(decoders.WordPiece.new()) + + if add_special_tokens: + self._tokenizer.with_post_processor(processors.BertProcessing.new( + (sep_token, self._tokenizer.token_to_id(sep_token)), + (cls_token, self._tokenizer.token_to_id(cls_token)), + )) + if max_length is not None: + self._tokenizer.with_truncation(max_length, stride, truncation_strategy) + self._tokenizer.with_padding( + max_length if pad_to_max_length else None, + self.padding_side, + self.pad_token_id, + self.pad_token_type_id, + self.pad_token + ) + self._decoder = decoders.WordPiece.new() + + except (AttributeError, ImportError) as e: + logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`") + raise e \ No newline at end of file From 31c56f2e0b908a7f7f5669b3d535c65f156f6556 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 12:43:27 -0500 Subject: [PATCH 04/11] Fix style --- src/transformers/tokenization_bert.py | 72 ++++++++++++------- src/transformers/tokenization_gpt2.py | 28 ++++++-- src/transformers/tokenization_utils.py | 97 +++++++++++++++----------- 3 files changed, 125 insertions(+), 72 deletions(-) diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index e2ba03594b..05aa42cadc 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -20,7 +20,7 @@ import logging import os import unicodedata -from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer +from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer logger = logging.getLogger(__name__) @@ -526,42 +526,64 @@ def _is_punctuation(char): return True return False + class BertTokenizerFast(FastPreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, - unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", - mask_token="[MASK]", tokenize_chinese_chars=True, - max_length=None, pad_to_max_length=False, stride=0, - truncation_strategy='longest_first', add_special_tokens=True, **kwargs): + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + max_length=None, + pad_to_max_length=False, + stride=0, + truncation_strategy="longest_first", + add_special_tokens=True, + **kwargs + ): try: from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors - super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token, - pad_token=pad_token, cls_token=cls_token, - mask_token=mask_token, **kwargs) - self._tokenizer = Tokenizer(models.WordPiece.from_files( - vocab_file, - unk_token=unk_token - )) + super(BertTokenizerFast, self).__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs + ) + + self._tokenizer = Tokenizer(models.WordPiece.from_files(vocab_file, unk_token=unk_token)) self._update_special_tokens() - self._tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new( - do_basic_tokenize=do_basic_tokenize, - do_lower_case=do_lower_case, - tokenize_chinese_chars=tokenize_chinese_chars, - never_split=never_split if never_split is not None else [], - )) + self._tokenizer.with_pre_tokenizer( + pre_tokenizers.BertPreTokenizer.new( + do_basic_tokenize=do_basic_tokenize, + do_lower_case=do_lower_case, + tokenize_chinese_chars=tokenize_chinese_chars, + never_split=never_split if never_split is not None else [], + ) + ) self._tokenizer.with_decoder(decoders.WordPiece.new()) if add_special_tokens: - self._tokenizer.with_post_processor(processors.BertProcessing.new( - (sep_token, self._tokenizer.token_to_id(sep_token)), - (cls_token, self._tokenizer.token_to_id(cls_token)), - )) + self._tokenizer.with_post_processor( + processors.BertProcessing.new( + (sep_token, self._tokenizer.token_to_id(sep_token)), + (cls_token, self._tokenizer.token_to_id(cls_token)), + ) + ) if max_length is not None: self._tokenizer.with_truncation(max_length, stride, truncation_strategy) self._tokenizer.with_padding( @@ -569,10 +591,10 @@ class BertTokenizerFast(FastPreTrainedTokenizer): self.padding_side, self.pad_token_id, self.pad_token_type_id, - self.pad_token + self.pad_token, ) self._decoder = decoders.WordPiece.new() except (AttributeError, ImportError) as e: logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`") - raise e \ No newline at end of file + raise e diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 5b5be1253e..9514975079 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -22,7 +22,7 @@ from functools import lru_cache import regex as re -from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer +from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer logger = logging.getLogger(__name__) @@ -247,19 +247,33 @@ class GPT2Tokenizer(PreTrainedTokenizer): return vocab_file, merge_file + class GPT2TokenizerFast(FastPreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, unk_token="<|endoftext|>", bos_token="<|endoftext|>", - eos_token="<|endoftext|>", pad_to_max_length=False, add_prefix_space=False, - max_length=None, stride=0, truncation_strategy='longest_first', **kwargs): + def __init__( + self, + vocab_file, + merges_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + pad_to_max_length=False, + add_prefix_space=False, + max_length=None, + stride=0, + truncation_strategy="longest_first", + **kwargs + ): try: from tokenizers import Tokenizer, models, pre_tokenizers, decoders - super(GPT2TokenizerFast, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) + super(GPT2TokenizerFast, self).__init__( + bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs + ) self._tokenizer = Tokenizer(models.BPE.from_files(vocab_file, merges_file)) self._update_special_tokens() @@ -272,10 +286,10 @@ class GPT2TokenizerFast(FastPreTrainedTokenizer): self.padding_side, self.pad_token_id if self.pad_token_id is not None else 0, self.pad_token_type_id, - self.pad_token if self.pad_token is not None else "" + self.pad_token if self.pad_token is not None else "", ) self._decoder = decoders.ByteLevel.new() except (AttributeError, ImportError) as e: logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`") - raise e \ No newline at end of file + raise e diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index ac5e2d9d07..57e2b909f7 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1411,6 +1411,7 @@ class PreTrainedTokenizer(object): ) return out_string + class FastPreTrainedTokenizer(PreTrainedTokenizer): def __init__(self, **kwargs): super(FastPreTrainedTokenizer, self).__init__(**kwargs) @@ -1438,12 +1439,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer): self.tokenizer.add_special_tokens(self.all_special_tokens) @staticmethod - def _convert_encoding(encoding, - return_tensors=None, - return_token_type_ids=True, - return_attention_mask=True, - return_overflowing_tokens=False, - return_special_tokens_mask=False): + def _convert_encoding( + encoding, + return_tensors=None, + return_token_type_ids=True, + return_attention_mask=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + ): encoding_dict = { "input_ids": encoding.ids, } @@ -1458,14 +1461,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer): encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask # Prepare inputs as tensors if asked - if return_tensors == 'tf' and is_tf_available(): + if return_tensors == "tf" and is_tf_available(): encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]]) encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]]) if "attention_mask" in encoding_dict: encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]]) - elif return_tensors == 'pt' and is_torch_available(): + elif return_tensors == "pt" and is_torch_available(): encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]]) encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]]) @@ -1474,26 +1477,32 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer): elif return_tensors is not None: logger.warning( "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( - return_tensors)) + return_tensors + ) + ) return encoding_dict - def encode_plus(self, - text, - text_pair=None, - return_tensors=None, - return_token_type_ids=True, - return_attention_mask=True, - return_overflowing_tokens=False, - return_special_tokens_mask=False, - **kwargs): + def encode_plus( + self, + text, + text_pair=None, + return_tensors=None, + return_token_type_ids=True, + return_attention_mask=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + **kwargs + ): encoding = self.tokenizer.encode(text, text_pair) - return self._convert_encoding(encoding, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask) + return self._convert_encoding( + encoding, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + ) def tokenize(self, text): return self.tokenizer.encode(text).tokens @@ -1510,19 +1519,26 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer): def add_tokens(self, new_tokens): self.tokenizer.add_tokens(new_tokens) - def encode_batch(self, texts, - return_tensors=None, - return_token_type_ids=True, - return_attention_mask=True, - return_overflowing_tokens=False, - return_special_tokens_mask=False): - return [self._convert_encoding(encoding, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask) - for encoding in self.tokenizer.encode_batch(texts)] + def encode_batch( + self, + texts, + return_tensors=None, + return_token_type_ids=True, + return_attention_mask=True, + return_overflowing_tokens=False, + return_special_tokens_mask=False, + ): + return [ + self._convert_encoding( + encoding, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + ) + for encoding in self.tokenizer.encode_batch(texts) + ] def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): text = self.tokenizer.decode(token_ids, skip_special_tokens) @@ -1534,6 +1550,7 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer): return text def decode_batch(self, ids_batch, skip_special_tokens=False, clear_up_tokenization_spaces=True): - return [self.clean_up_tokenization(text) - if clear_up_tokenization_spaces else text - for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)] \ No newline at end of file + return [ + self.clean_up_tokenization(text) if clear_up_tokenization_spaces else text + for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens) + ] From 2818e505694ee4b5b02a9c7b51faf4dd137728d4 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 13:29:01 -0500 Subject: [PATCH 05/11] Add tests for fast tokenizers --- tests/test_tokenization_bert.py | 27 ++++++++++++++++++++++ tests/test_tokenization_common.py | 4 ++++ tests/test_tokenization_gpt2.py | 37 ++++++++++++++++++++++++++++++- 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/test_tokenization_bert.py b/tests/test_tokenization_bert.py index 24e008d734..7af6cbee73 100644 --- a/tests/test_tokenization_bert.py +++ b/tests/test_tokenization_bert.py @@ -21,6 +21,7 @@ from transformers.tokenization_bert import ( VOCAB_FILES_NAMES, BasicTokenizer, BertTokenizer, + BertTokenizerFast, WordpieceTokenizer, _is_control, _is_punctuation, @@ -34,6 +35,7 @@ from .utils import slow class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = BertTokenizer + test_rust_tokenizer = True def setUp(self): super(BertTokenizationTest, self).setUp() @@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def get_tokenizer(self, **kwargs): return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) + def get_rust_tokenizer(self, **kwargs): + return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + def get_input_output_texts(self): input_text = "UNwant\u00E9d,running" output_text = "unwanted, running" @@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False) + + sequence = u"UNwant\u00E9d,running" + + tokens = tokenizer.tokenize(sequence) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + ids = tokenizer.encode(sequence, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + rust_tokenizer = self.get_rust_tokenizer() + ids = tokenizer.encode(sequence) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + def test_chinese(self): tokenizer = BasicTokenizer() diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 035a0dc27f..1fa965ea40 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -23,6 +23,7 @@ import tempfile class TokenizerTesterMixin: tokenizer_class = None + test_rust_tokenizer = False def setUp(self): self.tmpdirname = tempfile.mkdtemp() @@ -33,6 +34,9 @@ class TokenizerTesterMixin: def get_tokenizer(self, **kwargs): raise NotImplementedError + def get_rust_tokenizer(self, **kwargs): + raise NotImplementedError + def get_input_output_texts(self): raise NotImplementedError diff --git a/tests/test_tokenization_gpt2.py b/tests/test_tokenization_gpt2.py index 7353d55178..fdd8026a8f 100644 --- a/tests/test_tokenization_gpt2.py +++ b/tests/test_tokenization_gpt2.py @@ -18,7 +18,7 @@ import json import os import unittest -from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer +from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast from .test_tokenization_common import TokenizerTesterMixin @@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = GPT2Tokenizer + test_rust_tokenizer = True def setUp(self): super(GPT2TokenizationTest, self).setUp() @@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): kwargs.update(self.special_tokens_map) return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) + def get_rust_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + def get_input_output_texts(self): input_text = "lower newer" output_text = "lower newer" @@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): input_tokens = tokens + [tokenizer.unk_token] input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True) + + sequence = u"lower newer" + + # Testing tokenization + tokens = tokenizer.tokenize(sequence, add_prefix_space=True) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + # Testing conversion to ids without special tokens + ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + # Testing conversion to ids with special tokens + rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True) + ids = tokenizer.encode(sequence, add_prefix_space=True) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + # Testing the unknown token + input_tokens = tokens + [rust_tokenizer.unk_token] + input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] + self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) From 734d29b03d298936db7d7a41824ce15065bdf16c Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 13:32:41 -0500 Subject: [PATCH 06/11] tokenizers is now a real dependency --- setup.py | 1 + src/transformers/tokenization_bert.py | 76 ++++++++++++--------------- src/transformers/tokenization_gpt2.py | 43 +++++++-------- 3 files changed, 54 insertions(+), 66 deletions(-) diff --git a/setup.py b/setup.py index 558a38ea8b..52fef444f7 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ setup( packages=find_packages("src"), install_requires=[ "numpy", + "tokenizers", # accessing files from S3 directly "boto3", # filesystem locks e.g. to prevent parallel downloads diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index 05aa42cadc..cb78f03df7 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -20,6 +20,8 @@ import logging import os import unicodedata +import tokenizers as tk + from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer @@ -552,49 +554,41 @@ class BertTokenizerFast(FastPreTrainedTokenizer): add_special_tokens=True, **kwargs ): + super(BertTokenizerFast, self).__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs + ) - try: - from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors - - super(BertTokenizerFast, self).__init__( - unk_token=unk_token, - sep_token=sep_token, - pad_token=pad_token, - cls_token=cls_token, - mask_token=mask_token, - **kwargs + self._tokenizer = tk.Tokenizer(tk.models.WordPiece.from_files(vocab_file, unk_token=unk_token)) + self._update_special_tokens() + self._tokenizer.with_pre_tokenizer( + tk.pre_tokenizers.BertPreTokenizer.new( + do_basic_tokenize=do_basic_tokenize, + do_lower_case=do_lower_case, + tokenize_chinese_chars=tokenize_chinese_chars, + never_split=never_split if never_split is not None else [], ) + ) + self._tokenizer.with_decoder(tk.decoders.WordPiece.new()) - self._tokenizer = Tokenizer(models.WordPiece.from_files(vocab_file, unk_token=unk_token)) - self._update_special_tokens() - self._tokenizer.with_pre_tokenizer( - pre_tokenizers.BertPreTokenizer.new( - do_basic_tokenize=do_basic_tokenize, - do_lower_case=do_lower_case, - tokenize_chinese_chars=tokenize_chinese_chars, - never_split=never_split if never_split is not None else [], + if add_special_tokens: + self._tokenizer.with_post_processor( + tk.processors.BertProcessing.new( + (sep_token, self._tokenizer.token_to_id(sep_token)), + (cls_token, self._tokenizer.token_to_id(cls_token)), ) ) - self._tokenizer.with_decoder(decoders.WordPiece.new()) - - if add_special_tokens: - self._tokenizer.with_post_processor( - processors.BertProcessing.new( - (sep_token, self._tokenizer.token_to_id(sep_token)), - (cls_token, self._tokenizer.token_to_id(cls_token)), - ) - ) - if max_length is not None: - self._tokenizer.with_truncation(max_length, stride, truncation_strategy) - self._tokenizer.with_padding( - max_length if pad_to_max_length else None, - self.padding_side, - self.pad_token_id, - self.pad_token_type_id, - self.pad_token, - ) - self._decoder = decoders.WordPiece.new() - - except (AttributeError, ImportError) as e: - logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`") - raise e + if max_length is not None: + self._tokenizer.with_truncation(max_length, stride, truncation_strategy) + self._tokenizer.with_padding( + max_length if pad_to_max_length else None, + self.padding_side, + self.pad_token_id, + self.pad_token_type_id, + self.pad_token, + ) + self._decoder = tk.decoders.WordPiece.new() diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 9514975079..bba5eeb762 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -21,6 +21,7 @@ import os from functools import lru_cache import regex as re +import tokenizers as tk from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer @@ -267,29 +268,21 @@ class GPT2TokenizerFast(FastPreTrainedTokenizer): truncation_strategy="longest_first", **kwargs ): + super(GPT2TokenizerFast, self).__init__( + bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs + ) - try: - from tokenizers import Tokenizer, models, pre_tokenizers, decoders - - super(GPT2TokenizerFast, self).__init__( - bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs - ) - - self._tokenizer = Tokenizer(models.BPE.from_files(vocab_file, merges_file)) - self._update_special_tokens() - self._tokenizer.with_pre_tokenizer(pre_tokenizers.ByteLevel.new(add_prefix_space)) - self._tokenizer.with_decoder(decoders.ByteLevel.new()) - if max_length: - self._tokenizer.with_truncation(max_length, stride, truncation_strategy) - self._tokenizer.with_padding( - max_length if pad_to_max_length else None, - self.padding_side, - self.pad_token_id if self.pad_token_id is not None else 0, - self.pad_token_type_id, - self.pad_token if self.pad_token is not None else "", - ) - self._decoder = decoders.ByteLevel.new() - - except (AttributeError, ImportError) as e: - logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`") - raise e + self._tokenizer = tk.Tokenizer(tk.models.BPE.from_files(vocab_file, merges_file)) + self._update_special_tokens() + self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space)) + self._tokenizer.with_decoder(tk.decoders.ByteLevel.new()) + if max_length: + self._tokenizer.with_truncation(max_length, stride, truncation_strategy) + self._tokenizer.with_padding( + max_length if pad_to_max_length else None, + self.padding_side, + self.pad_token_id if self.pad_token_id is not None else 0, + self.pad_token_type_id, + self.pad_token if self.pad_token is not None else "", + ) + self._decoder = tk.decoders.ByteLevel.new() From 1f82a5d910f866b7eab7b099502718e8380499fb Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 26 Dec 2019 14:37:55 -0500 Subject: [PATCH 07/11] Update for changes in tokenizers API --- setup.py | 2 +- src/transformers/tokenization_bert.py | 14 ++++++++------ src/transformers/tokenization_gpt2.py | 16 +++++++++------- src/transformers/tokenization_utils.py | 4 ++-- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 52fef444f7..2753236f9d 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ setup( packages=find_packages("src"), install_requires=[ "numpy", - "tokenizers", + "tokenizers == 0.0.10", # accessing files from S3 directly "boto3", # filesystem locks e.g. to prevent parallel downloads diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index cb78f03df7..5362806bdc 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -583,12 +583,14 @@ class BertTokenizerFast(FastPreTrainedTokenizer): ) ) if max_length is not None: - self._tokenizer.with_truncation(max_length, stride, truncation_strategy) + self._tokenizer.with_truncation(max_length, + stride=stride, + strategy=truncation_strategy) self._tokenizer.with_padding( - max_length if pad_to_max_length else None, - self.padding_side, - self.pad_token_id, - self.pad_token_type_id, - self.pad_token, + max_length=max_length if pad_to_max_length else None, + direction=self.padding_side, + pad_id=self.pad_token_id, + pad_type_id=self.pad_token_type_id, + pad_token=self.pad_token, ) self._decoder = tk.decoders.WordPiece.new() diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index bba5eeb762..4cc1d4708e 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -274,15 +274,17 @@ class GPT2TokenizerFast(FastPreTrainedTokenizer): self._tokenizer = tk.Tokenizer(tk.models.BPE.from_files(vocab_file, merges_file)) self._update_special_tokens() - self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space)) + self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space)) self._tokenizer.with_decoder(tk.decoders.ByteLevel.new()) if max_length: - self._tokenizer.with_truncation(max_length, stride, truncation_strategy) + self._tokenizer.with_truncation(max_length, + stride=stride, + strategy=truncation_strategy) self._tokenizer.with_padding( - max_length if pad_to_max_length else None, - self.padding_side, - self.pad_token_id if self.pad_token_id is not None else 0, - self.pad_token_type_id, - self.pad_token if self.pad_token is not None else "", + max_length=max_length if pad_to_max_length else None, + direction=self.padding_side, + pad_id=self.pad_token_id if self.pad_token_id is not None else 0, + pad_type_id=self.pad_token_type_id, + pad_token=self.pad_token if self.pad_token is not None else "", ) self._decoder = tk.decoders.ByteLevel.new() diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 57e2b909f7..c52ba6ff79 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1430,10 +1430,10 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer): @property def vocab_size(self): - return self.tokenizer.get_vocab_size(False) + return self.tokenizer.get_vocab_size(with_added_tokens=False) def __len__(self): - return self.tokenizer.get_vocab_size(True) + return self.tokenizer.get_vocab_size(with_added_tokens=True) def _update_special_tokens(self): self.tokenizer.add_special_tokens(self.all_special_tokens) From 7ead04ce1453940914abfdd24c826814c117f49e Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 26 Dec 2019 14:39:39 -0500 Subject: [PATCH 08/11] FastPreTrainedTokenizer => PreTrainedTokenizerFast --- src/transformers/tokenization_bert.py | 4 ++-- src/transformers/tokenization_gpt2.py | 4 ++-- src/transformers/tokenization_utils.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index 5362806bdc..4d41654037 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -22,7 +22,7 @@ import unicodedata import tokenizers as tk -from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer +from .tokenization_utils import PreTrainedTokenizerFast, PreTrainedTokenizer logger = logging.getLogger(__name__) @@ -529,7 +529,7 @@ def _is_punctuation(char): return False -class BertTokenizerFast(FastPreTrainedTokenizer): +class BertTokenizerFast(PreTrainedTokenizerFast): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 4cc1d4708e..1cdc263e25 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -23,7 +23,7 @@ from functools import lru_cache import regex as re import tokenizers as tk -from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer +from .tokenization_utils import PreTrainedTokenizerFast, PreTrainedTokenizer logger = logging.getLogger(__name__) @@ -249,7 +249,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): return vocab_file, merge_file -class GPT2TokenizerFast(FastPreTrainedTokenizer): +class GPT2TokenizerFast(PreTrainedTokenizerFast): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index c52ba6ff79..535322be10 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1412,9 +1412,9 @@ class PreTrainedTokenizer(object): return out_string -class FastPreTrainedTokenizer(PreTrainedTokenizer): +class PreTrainedTokenizerFast(PreTrainedTokenizer): def __init__(self, **kwargs): - super(FastPreTrainedTokenizer, self).__init__(**kwargs) + super(PreTrainedTokenizerFast, self).__init__(**kwargs) @property def tokenizer(self): From 835b76a46f37ecaa00eebf6d7f190297b44d6103 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 26 Dec 2019 14:42:55 -0500 Subject: [PATCH 09/11] Handle unk_token As we discussed, this is handled here directly cc @thomwolf --- src/transformers/tokenization_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 535322be10..210e47e752 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1508,7 +1508,10 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): return self.tokenizer.encode(text).tokens def _convert_token_to_id_with_added_voc(self, token): - return self.tokenizer.token_to_id(token) + id = self.tokenizer.token_to_id(token) + if id is None: + return self.unk_token_id + return id def _convert_id_to_token(self, index): return self.tokenizer.id_to_token(int(index)) From 599db139f921f3af535052c860cb685cadae6fcd Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 26 Dec 2019 15:13:30 -0500 Subject: [PATCH 10/11] Code style update --- src/transformers/tokenization_bert.py | 6 ++---- src/transformers/tokenization_gpt2.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index 4d41654037..4c2abc7d17 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -22,7 +22,7 @@ import unicodedata import tokenizers as tk -from .tokenization_utils import PreTrainedTokenizerFast, PreTrainedTokenizer +from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast logger = logging.getLogger(__name__) @@ -583,9 +583,7 @@ class BertTokenizerFast(PreTrainedTokenizerFast): ) ) if max_length is not None: - self._tokenizer.with_truncation(max_length, - stride=stride, - strategy=truncation_strategy) + self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy) self._tokenizer.with_padding( max_length=max_length if pad_to_max_length else None, direction=self.padding_side, diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 1cdc263e25..c99ec08ffc 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -23,7 +23,7 @@ from functools import lru_cache import regex as re import tokenizers as tk -from .tokenization_utils import PreTrainedTokenizerFast, PreTrainedTokenizer +from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast logger = logging.getLogger(__name__) @@ -277,9 +277,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast): self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space)) self._tokenizer.with_decoder(tk.decoders.ByteLevel.new()) if max_length: - self._tokenizer.with_truncation(max_length, - stride=stride, - strategy=truncation_strategy) + self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy) self._tokenizer.with_padding( max_length=max_length if pad_to_max_length else None, direction=self.padding_side, From e6ec24fa881446e7c06fd5ab2cbc461899428c54 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 26 Dec 2019 16:49:48 -0500 Subject: [PATCH 11/11] Better added_tokens handling --- src/transformers/tokenization_utils.py | 51 +++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 210e47e752..8fa85a2f7c 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1413,6 +1413,9 @@ class PreTrainedTokenizer(object): class PreTrainedTokenizerFast(PreTrainedTokenizer): + _tokenizer = None + _decoder = None + def __init__(self, **kwargs): super(PreTrainedTokenizerFast, self).__init__(**kwargs) @@ -1435,8 +1438,49 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): def __len__(self): return self.tokenizer.get_vocab_size(with_added_tokens=True) + @PreTrainedTokenizer.bos_token.setter + def bos_token(self, value): + self._bos_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.eos_token.setter + def eos_token(self, value): + self._eos_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.unk_token.setter + def unk_token(self, value): + self._unk_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.sep_token.setter + def sep_token(self, value): + self._sep_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.pad_token.setter + def pad_token(self, value): + self._pad_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.cls_token.setter + def cls_token(self, value): + self._cls_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.mask_token.setter + def mask_token(self, value): + self._mask_token = value + self._update_special_tokens() + + @PreTrainedTokenizer.additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + self._update_special_tokens() + def _update_special_tokens(self): - self.tokenizer.add_special_tokens(self.all_special_tokens) + if self._tokenizer is not None: + self._tokenizer.add_special_tokens(self.all_special_tokens) @staticmethod def _convert_encoding( @@ -1522,6 +1566,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): def add_tokens(self, new_tokens): self.tokenizer.add_tokens(new_tokens) + def add_special_tokens(self, special_tokens_dict): + added = super().add_special_tokens(special_tokens_dict) + self._update_special_tokens() + return added + def encode_batch( self, texts,