add serialization semantics to tokenizers - fix transfo-xl tokenizer

This commit is contained in:
thomwolf
2019-04-15 11:47:25 +02:00
parent 616743330e
commit 3e65f255dc
5 changed files with 67 additions and 110 deletions

View File

@@ -187,6 +187,22 @@ class GPT2Tokenizer(object):
self.cache[token] = word
return word
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a path."""
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
json.dump(self.encoder, vocab_file)
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(bpe_tokens + u'\n')
index += 1
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):