added tokenizers serialization tests

This commit is contained in:
thomwolf
2019-04-15 12:03:56 +02:00
parent 3e65f255dc
commit 870b734bfd
7 changed files with 51 additions and 32 deletions

View File

@@ -263,7 +263,10 @@ class OpenAIGPTTokenizer(object):
return out_string
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a path."""
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
json.dump(self.encoder, vocab_file)
@@ -277,3 +280,4 @@ class OpenAIGPTTokenizer(object):
index = token_index
writer.write(bpe_tokens + u'\n')
index += 1
return vocab_file, merge_file