improving GPT2 tokenization and adding tests

This commit is contained in:
thomwolf
2019-04-16 17:00:55 +02:00
parent 3d78e226e6
commit 18a8a15f78
5 changed files with 169 additions and 34 deletions

View File

@@ -150,6 +150,8 @@ class OpenAIGPTTokenizer(object):
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
@@ -261,7 +263,10 @@ class OpenAIGPTTokenizer(object):
tokens.append(self.decoder[i])
return tokens
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids into a string."""
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
out_string = ''.join(tokens).replace('</w>', ' ').strip()
@@ -296,8 +301,14 @@ class OpenAIGPTTokenizer(object):
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file