GPT2TokenizerFast
This commit is contained in:
@@ -108,7 +108,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenize
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||
from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
from .tokenization_t5 import T5Tokenizer
|
||||
|
||||
@@ -22,7 +22,7 @@ from functools import lru_cache
|
||||
|
||||
import regex as re
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -246,3 +246,36 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file
|
||||
|
||||
class GPT2TokenizerFast(FastPreTrainedTokenizer):
    """GPT-2 tokenizer backed by the byte-level BPE of the `tokenizers` package.

    The underlying ``tokenizers.Tokenizer`` is built from a vocabulary file and
    a merges file, then configured with a byte-level pre-tokenizer, a
    byte-level decoder, optional truncation, and padding.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        merges_file,
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_to_max_length=False,
        add_prefix_space=False,
        max_length=None,
        stride=0,
        truncation_strategy='longest_first',
        **kwargs
    ):
        """Build and configure the underlying `tokenizers` tokenizer.

        Args:
            vocab_file: Vocabulary file handed to ``models.BPE.from_files``.
            merges_file: Merges file handed to ``models.BPE.from_files``.
            unk_token: Unknown token forwarded to the base initializer.
            bos_token: Beginning-of-sequence token forwarded to the base
                initializer.
            eos_token: End-of-sequence token forwarded to the base initializer.
            pad_to_max_length: When true, ``max_length`` is passed as the
                padding length; otherwise ``None`` is passed.
            add_prefix_space: Forwarded to ``pre_tokenizers.ByteLevel.new``.
            max_length: When truthy, truncation is enabled with this length.
            stride: Forwarded to ``with_truncation``.
            truncation_strategy: Forwarded to ``with_truncation``.
            **kwargs: Extra keyword arguments forwarded to the base
                initializer.

        Raises:
            AttributeError, ImportError: Re-raised after logging a hint about
                the expected `tokenizers` version.
        """
        try:
            # Imported lazily so a missing or incompatible `tokenizers`
            # install surfaces through the handler below.
            from tokenizers import Tokenizer, models, pre_tokenizers, decoders

            super(GPT2TokenizerFast, self).__init__(
                bos_token=bos_token,
                eos_token=eos_token,
                unk_token=unk_token,
                **kwargs
            )

            bpe_model = models.BPE.from_files(vocab_file, merges_file)
            self._tokenizer = Tokenizer(bpe_model)
            self._update_special_tokens()

            byte_level_pre_tokenizer = pre_tokenizers.ByteLevel.new(add_prefix_space)
            self._tokenizer.with_pre_tokenizer(byte_level_pre_tokenizer)
            self._tokenizer.with_decoder(decoders.ByteLevel.new())

            # Truncation is only switched on when a (truthy) max_length is given.
            if max_length:
                self._tokenizer.with_truncation(max_length, stride, truncation_strategy)

            # A padding length is only supplied when padding to max_length was requested.
            padding_length = max_length if pad_to_max_length else None
            self._tokenizer.with_padding(
                padding_length,
                self.padding_side,
                self.pad_token_id if self.pad_token_id is not None else 0,
                self.pad_token_type_id,
                self.pad_token if self.pad_token is not None else ""
            )

            # A separate decoder instance is kept on the wrapper itself.
            self._decoder = decoders.ByteLevel.new()
        except (AttributeError, ImportError) as err:
            logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`")
            raise err
|
||||
Reference in New Issue
Block a user