GPT2TokenizerFast
This commit is contained in:
@@ -108,7 +108,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenize
|
|||||||
from .tokenization_camembert import CamembertTokenizer
|
from .tokenization_camembert import CamembertTokenizer
|
||||||
from .tokenization_ctrl import CTRLTokenizer
|
from .tokenization_ctrl import CTRLTokenizer
|
||||||
from .tokenization_distilbert import DistilBertTokenizer
|
from .tokenization_distilbert import DistilBertTokenizer
|
||||||
from .tokenization_gpt2 import GPT2Tokenizer
|
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer
|
from .tokenization_openai import OpenAIGPTTokenizer
|
||||||
from .tokenization_roberta import RobertaTokenizer
|
from .tokenization_roberta import RobertaTokenizer
|
||||||
from .tokenization_t5 import T5Tokenizer
|
from .tokenization_t5 import T5Tokenizer
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from functools import lru_cache
|
|||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -246,3 +246,36 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
|||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
return vocab_file, merge_file
|
return vocab_file, merge_file
|
||||||
|
|
||||||
|
class GPT2TokenizerFast(FastPreTrainedTokenizer):
    """GPT-2 tokenizer backed by the `tokenizers` library.

    Builds a `tokenizers.Tokenizer` around a BPE model loaded from a
    vocabulary file and a merges file, using ByteLevel pre-tokenization and
    ByteLevel decoding.
    """

    # Class-level lookup tables; their values are module-level constants
    # defined elsewhere in this file.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, unk_token="<|endoftext|>", bos_token="<|endoftext|>",
                 eos_token="<|endoftext|>", pad_to_max_length=False, add_prefix_space=False,
                 max_length=None, stride=0, truncation_strategy='longest_first', **kwargs):
        """Create the tokenizer and configure its pre-tokenizer, decoder,
        truncation and padding.

        Args:
            vocab_file: Vocabulary file handed to `models.BPE.from_files`.
            merges_file: Merges file handed to `models.BPE.from_files`.
            unk_token: Unknown token, forwarded to the base class constructor.
            bos_token: Beginning-of-sequence token, forwarded to the base
                class constructor.
            eos_token: End-of-sequence token, forwarded to the base class
                constructor.
            pad_to_max_length: When truthy, `max_length` is passed to
                `with_padding` as its first argument; otherwise `None` is
                passed.
            add_prefix_space: Forwarded to `pre_tokenizers.ByteLevel.new`.
            max_length: When truthy, truncation is enabled through
                `with_truncation`.
            stride: Forwarded to `with_truncation`.
            truncation_strategy: Forwarded to `with_truncation`.
            **kwargs: Extra keyword arguments forwarded to the base class
                constructor.

        Raises:
            ImportError: Re-raised after logging an install hint when
                `tokenizers` cannot be imported.
            AttributeError: Re-raised after logging the same hint.
        """
        # NOTE(review): the whole setup lives inside the `try`, so an
        # AttributeError raised anywhere below is reported with the install
        # hint — presumably to flag an incompatible `tokenizers` version;
        # confirm this scope is intended.
        try:
            from tokenizers import Tokenizer, models, pre_tokenizers, decoders

            super(GPT2TokenizerFast, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)

            self._tokenizer = Tokenizer(models.BPE.from_files(vocab_file, merges_file))
            self._update_special_tokens()
            self._tokenizer.with_pre_tokenizer(pre_tokenizers.ByteLevel.new(add_prefix_space))
            self._tokenizer.with_decoder(decoders.ByteLevel.new())
            if max_length:
                self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
            # Padding is always configured; unset pad id / pad token fall
            # back to 0 and the empty string respectively.
            self._tokenizer.with_padding(
                max_length if pad_to_max_length else None,
                self.padding_side,
                self.pad_token_id if self.pad_token_id is not None else 0,
                self.pad_token_type_id,
                self.pad_token if self.pad_token is not None else ""
            )
            self._decoder = decoders.ByteLevel.new()

        except (AttributeError, ImportError):
            logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`")
            # Bare `raise` re-raises the active exception with its original
            # traceback intact.
            raise
|
||||||
Reference in New Issue
Block a user