fixing tests

This commit is contained in:
thomwolf
2019-04-15 12:55:38 +02:00
parent 870b734bfd
commit e8568a3b17
4 changed files with 51 additions and 14 deletions

View File

@@ -45,6 +45,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
@@ -97,6 +98,11 @@ class GPT2Tokenizer(object):
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -125,7 +131,11 @@ class GPT2Tokenizer(object):
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
@@ -194,7 +204,11 @@ class GPT2Tokenizer(object):
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
json.dump(self.encoder, vocab_file)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
@@ -203,9 +217,14 @@ class GPT2Tokenizer(object):
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(bpe_tokens + u'\n')
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_file, merge_file
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
writer.write(token + u'\n')
return vocab_file, merge_file, special_tokens_file
def encode(self, text):
bpe_tokens = []