clean up tokenization - fix python 2 tests

This commit is contained in:
thomwolf
2019-02-18 11:27:18 +01:00
parent d44db1145c
commit b450a7faf2

View File

@@ -20,14 +20,19 @@ import json
import logging import logging
import os import os
import regex as re import regex as re
import sys
from io import open from io import open
from functools import lru_cache
from tqdm import tqdm try:
from functools import lru_cache
except ImportError:
# Python 2 has no functools.lru_cache, so fall back to a dummy decorator that
# does no caching; it only exists so that the checks can run on Python 2.
# Honestly, a byte-level unicode BPE tokenizer is not supported on Python 2 right now.
def lru_cache(func=None, *_options, **_kw_options):
    """No-op stand-in for ``functools.lru_cache`` when it cannot be imported.

    Performs no caching at all: the decorated function is simply called
    through with the arguments it was given. Both decorator spellings of the
    real ``lru_cache`` are accepted, so code written against it keeps working:

    * bare form, ``@lru_cache``: ``func`` is the function to wrap and the
      pass-through wrapper is returned directly;
    * factory form, ``@lru_cache()`` / ``@lru_cache(maxsize=...)``: the cache
      options are accepted and ignored, and a decorator is returned.

    Args:
        func: The function to wrap (bare form), or a cache option / ``None``
            (factory form). Any non-callable value is treated as an option.
        *_options: Ignored positional cache options.
        **_kw_options: Ignored keyword cache options (e.g. ``maxsize``).

    Returns:
        The pass-through wrapper when ``func`` is callable, otherwise a
        decorator that produces such a wrapper.
    """
    # Imported locally so the shim does not depend on a module-level import.
    from functools import wraps

    def decorate(wrapped):
        @wraps(wrapped)
        def func_wrapper(*inputs, **args):
            # Unpack the collected arguments; forwarding ``inputs`` and
            # ``args`` as-is would hand the callee a tuple and a dict.
            return wrapped(*inputs, **args)
        return func_wrapper

    if callable(func):
        return decorate(func)
    return decorate
from .file_utils import cached_path from .file_utils import cached_path
from .tokenization import BasicTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -125,7 +130,8 @@ class GPT2Tokenizer(object):
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
return tokenizer return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace'): def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file)) self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding self.errors = errors # how to handle errors in decoding
@@ -188,6 +194,12 @@ class GPT2Tokenizer(object):
for token in re.findall(self.pat, text): for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
if len(bpe_tokens) > self.max_len:
raise ValueError(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
)
return bpe_tokens return bpe_tokens
def decode(self, tokens): def decode(self, tokens):