add serialization semantics to tokenizers - fix transfo-xl tokenizer

This commit is contained in:
thomwolf
2019-04-15 11:47:25 +02:00
parent 616743330e
commit 3e65f255dc
5 changed files with 67 additions and 110 deletions

View File

@@ -134,6 +134,19 @@ class BertTokenizer(object):
tokens.append(self.ids_to_tokens[i])
return tokens
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a path."""
index = 0
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file))
index = token_index
writer.write(token + u'\n')
index += 1
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""