add serialization semantics to tokenizers - fix transfo-xl tokenizer

This commit is contained in:
thomwolf
2019-04-15 11:47:25 +02:00
parent 616743330e
commit 3e65f255dc
5 changed files with 67 additions and 110 deletions

View File

@@ -261,3 +261,19 @@ class OpenAIGPTTokenizer(object):
).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
).replace(" 've", "'ve")
return out_string
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a path."""
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
json.dump(self.encoder, vocab_file)
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(bpe_tokens + u'\n')
index += 1