Merge pull request #2211 from huggingface/fast-tokenizers
Fast tokenizers
This commit is contained in:
1
setup.py
1
setup.py
@@ -86,6 +86,7 @@ setup(
|
|||||||
packages=find_packages("src"),
|
packages=find_packages("src"),
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"numpy",
|
"numpy",
|
||||||
|
"tokenizers == 0.0.10",
|
||||||
# accessing files from S3 directly
|
# accessing files from S3 directly
|
||||||
"boto3",
|
"boto3",
|
||||||
# filesystem locks e.g. to prevent parallel downloads
|
# filesystem locks e.g. to prevent parallel downloads
|
||||||
|
|||||||
@@ -103,12 +103,12 @@ from .pipelines import (
|
|||||||
)
|
)
|
||||||
from .tokenization_albert import AlbertTokenizer
|
from .tokenization_albert import AlbertTokenizer
|
||||||
from .tokenization_auto import AutoTokenizer
|
from .tokenization_auto import AutoTokenizer
|
||||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
|
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||||
from .tokenization_camembert import CamembertTokenizer
|
from .tokenization_camembert import CamembertTokenizer
|
||||||
from .tokenization_ctrl import CTRLTokenizer
|
from .tokenization_ctrl import CTRLTokenizer
|
||||||
from .tokenization_distilbert import DistilBertTokenizer
|
from .tokenization_distilbert import DistilBertTokenizer
|
||||||
from .tokenization_gpt2 import GPT2Tokenizer
|
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer
|
from .tokenization_openai import OpenAIGPTTokenizer
|
||||||
from .tokenization_roberta import RobertaTokenizer
|
from .tokenization_roberta import RobertaTokenizer
|
||||||
from .tokenization_t5 import T5Tokenizer
|
from .tokenization_t5 import T5Tokenizer
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
import tokenizers as tk
|
||||||
|
|
||||||
|
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -525,3 +527,68 @@ def _is_punctuation(char):
|
|||||||
if cat.startswith("P"):
|
if cat.startswith("P"):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class BertTokenizerFast(PreTrainedTokenizerFast):
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
do_lower_case=True,
|
||||||
|
do_basic_tokenize=True,
|
||||||
|
never_split=None,
|
||||||
|
unk_token="[UNK]",
|
||||||
|
sep_token="[SEP]",
|
||||||
|
pad_token="[PAD]",
|
||||||
|
cls_token="[CLS]",
|
||||||
|
mask_token="[MASK]",
|
||||||
|
tokenize_chinese_chars=True,
|
||||||
|
max_length=None,
|
||||||
|
pad_to_max_length=False,
|
||||||
|
stride=0,
|
||||||
|
truncation_strategy="longest_first",
|
||||||
|
add_special_tokens=True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super(BertTokenizerFast, self).__init__(
|
||||||
|
unk_token=unk_token,
|
||||||
|
sep_token=sep_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
cls_token=cls_token,
|
||||||
|
mask_token=mask_token,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self._tokenizer = tk.Tokenizer(tk.models.WordPiece.from_files(vocab_file, unk_token=unk_token))
|
||||||
|
self._update_special_tokens()
|
||||||
|
self._tokenizer.with_pre_tokenizer(
|
||||||
|
tk.pre_tokenizers.BertPreTokenizer.new(
|
||||||
|
do_basic_tokenize=do_basic_tokenize,
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
|
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||||
|
never_split=never_split if never_split is not None else [],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._tokenizer.with_decoder(tk.decoders.WordPiece.new())
|
||||||
|
|
||||||
|
if add_special_tokens:
|
||||||
|
self._tokenizer.with_post_processor(
|
||||||
|
tk.processors.BertProcessing.new(
|
||||||
|
(sep_token, self._tokenizer.token_to_id(sep_token)),
|
||||||
|
(cls_token, self._tokenizer.token_to_id(cls_token)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if max_length is not None:
|
||||||
|
self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy)
|
||||||
|
self._tokenizer.with_padding(
|
||||||
|
max_length=max_length if pad_to_max_length else None,
|
||||||
|
direction=self.padding_side,
|
||||||
|
pad_id=self.pad_token_id,
|
||||||
|
pad_type_id=self.pad_token_type_id,
|
||||||
|
pad_token=self.pad_token,
|
||||||
|
)
|
||||||
|
self._decoder = tk.decoders.WordPiece.new()
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ import os
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
|
import tokenizers as tk
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -246,3 +247,42 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
|||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
return vocab_file, merge_file
|
return vocab_file, merge_file
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2TokenizerFast(PreTrainedTokenizerFast):
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
merges_file,
|
||||||
|
unk_token="<|endoftext|>",
|
||||||
|
bos_token="<|endoftext|>",
|
||||||
|
eos_token="<|endoftext|>",
|
||||||
|
pad_to_max_length=False,
|
||||||
|
add_prefix_space=False,
|
||||||
|
max_length=None,
|
||||||
|
stride=0,
|
||||||
|
truncation_strategy="longest_first",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super(GPT2TokenizerFast, self).__init__(
|
||||||
|
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self._tokenizer = tk.Tokenizer(tk.models.BPE.from_files(vocab_file, merges_file))
|
||||||
|
self._update_special_tokens()
|
||||||
|
self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space))
|
||||||
|
self._tokenizer.with_decoder(tk.decoders.ByteLevel.new())
|
||||||
|
if max_length:
|
||||||
|
self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy)
|
||||||
|
self._tokenizer.with_padding(
|
||||||
|
max_length=max_length if pad_to_max_length else None,
|
||||||
|
direction=self.padding_side,
|
||||||
|
pad_id=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||||
|
pad_type_id=self.pad_token_type_id,
|
||||||
|
pad_token=self.pad_token if self.pad_token is not None else "",
|
||||||
|
)
|
||||||
|
self._decoder = tk.decoders.ByteLevel.new()
|
||||||
|
|||||||
@@ -1414,3 +1414,199 @@ class PreTrainedTokenizer(object):
|
|||||||
.replace(" 're", "'re")
|
.replace(" 're", "'re")
|
||||||
)
|
)
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
|
|
||||||
|
class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
||||||
|
_tokenizer = None
|
||||||
|
_decoder = None
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(PreTrainedTokenizerFast, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self):
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise NotImplementedError
|
||||||
|
return self._tokenizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decoder(self):
|
||||||
|
if self._decoder is None:
|
||||||
|
raise NotImplementedError
|
||||||
|
return self._decoder
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return self.tokenizer.get_vocab_size(with_added_tokens=False)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.tokenizer.get_vocab_size(with_added_tokens=True)
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.bos_token.setter
|
||||||
|
def bos_token(self, value):
|
||||||
|
self._bos_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.eos_token.setter
|
||||||
|
def eos_token(self, value):
|
||||||
|
self._eos_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.unk_token.setter
|
||||||
|
def unk_token(self, value):
|
||||||
|
self._unk_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.sep_token.setter
|
||||||
|
def sep_token(self, value):
|
||||||
|
self._sep_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.pad_token.setter
|
||||||
|
def pad_token(self, value):
|
||||||
|
self._pad_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.cls_token.setter
|
||||||
|
def cls_token(self, value):
|
||||||
|
self._cls_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.mask_token.setter
|
||||||
|
def mask_token(self, value):
|
||||||
|
self._mask_token = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
@PreTrainedTokenizer.additional_special_tokens.setter
|
||||||
|
def additional_special_tokens(self, value):
|
||||||
|
self._additional_special_tokens = value
|
||||||
|
self._update_special_tokens()
|
||||||
|
|
||||||
|
def _update_special_tokens(self):
|
||||||
|
if self._tokenizer is not None:
|
||||||
|
self._tokenizer.add_special_tokens(self.all_special_tokens)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_encoding(
|
||||||
|
encoding,
|
||||||
|
return_tensors=None,
|
||||||
|
return_token_type_ids=True,
|
||||||
|
return_attention_mask=True,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
return_special_tokens_mask=False,
|
||||||
|
):
|
||||||
|
encoding_dict = {
|
||||||
|
"input_ids": encoding.ids,
|
||||||
|
}
|
||||||
|
if return_token_type_ids:
|
||||||
|
encoding_dict["token_type_ids"] = encoding.type_ids
|
||||||
|
if return_attention_mask:
|
||||||
|
encoding_dict["attention_mask"] = encoding.attention_mask
|
||||||
|
if return_overflowing_tokens:
|
||||||
|
overflowing = encoding.overflowing
|
||||||
|
encoding_dict["overflowing_tokens"] = overflowing.ids if overflowing is not None else []
|
||||||
|
if return_special_tokens_mask:
|
||||||
|
encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask
|
||||||
|
|
||||||
|
# Prepare inputs as tensors if asked
|
||||||
|
if return_tensors == "tf" and is_tf_available():
|
||||||
|
encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]])
|
||||||
|
encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]])
|
||||||
|
|
||||||
|
if "attention_mask" in encoding_dict:
|
||||||
|
encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]])
|
||||||
|
|
||||||
|
elif return_tensors == "pt" and is_torch_available():
|
||||||
|
encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]])
|
||||||
|
encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]])
|
||||||
|
|
||||||
|
if "attention_mask" in encoding_dict:
|
||||||
|
encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]])
|
||||||
|
elif return_tensors is not None:
|
||||||
|
logger.warning(
|
||||||
|
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
|
||||||
|
return_tensors
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return encoding_dict
|
||||||
|
|
||||||
|
def encode_plus(
|
||||||
|
self,
|
||||||
|
text,
|
||||||
|
text_pair=None,
|
||||||
|
return_tensors=None,
|
||||||
|
return_token_type_ids=True,
|
||||||
|
return_attention_mask=True,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
return_special_tokens_mask=False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoding = self.tokenizer.encode(text, text_pair)
|
||||||
|
return self._convert_encoding(
|
||||||
|
encoding,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
return_token_type_ids=return_token_type_ids,
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
|
return_overflowing_tokens=return_overflowing_tokens,
|
||||||
|
return_special_tokens_mask=return_special_tokens_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
return self.tokenizer.encode(text).tokens
|
||||||
|
|
||||||
|
def _convert_token_to_id_with_added_voc(self, token):
|
||||||
|
id = self.tokenizer.token_to_id(token)
|
||||||
|
if id is None:
|
||||||
|
return self.unk_token_id
|
||||||
|
return id
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
|
return self.tokenizer.id_to_token(int(index))
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens):
|
||||||
|
return self.decoder.decode(tokens)
|
||||||
|
|
||||||
|
def add_tokens(self, new_tokens):
|
||||||
|
self.tokenizer.add_tokens(new_tokens)
|
||||||
|
|
||||||
|
def add_special_tokens(self, special_tokens_dict):
|
||||||
|
added = super().add_special_tokens(special_tokens_dict)
|
||||||
|
self._update_special_tokens()
|
||||||
|
return added
|
||||||
|
|
||||||
|
def encode_batch(
|
||||||
|
self,
|
||||||
|
texts,
|
||||||
|
return_tensors=None,
|
||||||
|
return_token_type_ids=True,
|
||||||
|
return_attention_mask=True,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
return_special_tokens_mask=False,
|
||||||
|
):
|
||||||
|
return [
|
||||||
|
self._convert_encoding(
|
||||||
|
encoding,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
return_token_type_ids=return_token_type_ids,
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
|
return_overflowing_tokens=return_overflowing_tokens,
|
||||||
|
return_special_tokens_mask=return_special_tokens_mask,
|
||||||
|
)
|
||||||
|
for encoding in self.tokenizer.encode_batch(texts)
|
||||||
|
]
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||||
|
text = self.tokenizer.decode(token_ids, skip_special_tokens)
|
||||||
|
|
||||||
|
if clean_up_tokenization_spaces:
|
||||||
|
clean_text = self.clean_up_tokenization(text)
|
||||||
|
return clean_text
|
||||||
|
else:
|
||||||
|
return text
|
||||||
|
|
||||||
|
def decode_batch(self, ids_batch, skip_special_tokens=False, clear_up_tokenization_spaces=True):
|
||||||
|
return [
|
||||||
|
self.clean_up_tokenization(text) if clear_up_tokenization_spaces else text
|
||||||
|
for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)
|
||||||
|
]
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
|
|||||||
VOCAB_FILES_NAMES,
|
VOCAB_FILES_NAMES,
|
||||||
BasicTokenizer,
|
BasicTokenizer,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
BertTokenizerFast,
|
||||||
WordpieceTokenizer,
|
WordpieceTokenizer,
|
||||||
_is_control,
|
_is_control,
|
||||||
_is_punctuation,
|
_is_punctuation,
|
||||||
@@ -34,6 +35,7 @@ from .utils import slow
|
|||||||
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
tokenizer_class = BertTokenizer
|
tokenizer_class = BertTokenizer
|
||||||
|
test_rust_tokenizer = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(BertTokenizationTest, self).setUp()
|
super(BertTokenizationTest, self).setUp()
|
||||||
@@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
|
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
input_text = "UNwant\u00E9d,running"
|
input_text = "UNwant\u00E9d,running"
|
||||||
output_text = "unwanted, running"
|
output_text = "unwanted, running"
|
||||||
@@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
|
def test_rust_and_python_full_tokenizers(self):
|
||||||
|
if not self.test_rust_tokenizer:
|
||||||
|
return
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)
|
||||||
|
|
||||||
|
sequence = u"UNwant\u00E9d,running"
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize(sequence)
|
||||||
|
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||||
|
self.assertListEqual(tokens, rust_tokens)
|
||||||
|
|
||||||
|
ids = tokenizer.encode(sequence, add_special_tokens=False)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
ids = tokenizer.encode(sequence)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
def test_chinese(self):
|
def test_chinese(self):
|
||||||
tokenizer = BasicTokenizer()
|
tokenizer = BasicTokenizer()
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import tempfile
|
|||||||
class TokenizerTesterMixin:
|
class TokenizerTesterMixin:
|
||||||
|
|
||||||
tokenizer_class = None
|
tokenizer_class = None
|
||||||
|
test_rust_tokenizer = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdirname = tempfile.mkdtemp()
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
@@ -33,6 +34,9 @@ class TokenizerTesterMixin:
|
|||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer
|
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
|
|||||||
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
tokenizer_class = GPT2Tokenizer
|
tokenizer_class = GPT2Tokenizer
|
||||||
|
test_rust_tokenizer = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(GPT2TokenizationTest, self).setUp()
|
super(GPT2TokenizationTest, self).setUp()
|
||||||
@@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
kwargs.update(self.special_tokens_map)
|
kwargs.update(self.special_tokens_map)
|
||||||
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
|
kwargs.update(self.special_tokens_map)
|
||||||
|
return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
input_text = "lower newer"
|
input_text = "lower newer"
|
||||||
output_text = "lower newer"
|
output_text = "lower newer"
|
||||||
@@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
input_tokens = tokens + [tokenizer.unk_token]
|
input_tokens = tokens + [tokenizer.unk_token]
|
||||||
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
def test_rust_and_python_full_tokenizers(self):
|
||||||
|
if not self.test_rust_tokenizer:
|
||||||
|
return
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)
|
||||||
|
|
||||||
|
sequence = u"lower newer"
|
||||||
|
|
||||||
|
# Testing tokenization
|
||||||
|
tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
|
||||||
|
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||||
|
self.assertListEqual(tokens, rust_tokens)
|
||||||
|
|
||||||
|
# Testing conversion to ids without special tokens
|
||||||
|
ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
|
# Testing conversion to ids with special tokens
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
|
||||||
|
ids = tokenizer.encode(sequence, add_prefix_space=True)
|
||||||
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
|
# Testing the unknown token
|
||||||
|
input_tokens = tokens + [rust_tokenizer.unk_token]
|
||||||
|
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||||
|
self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user