Added tokenizer serialization tests

This commit is contained in:
thomwolf
2019-04-15 12:03:56 +02:00
parent 3e65f255dc
commit 870b734bfd
7 changed files with 51 additions and 32 deletions

View File

@@ -188,7 +188,10 @@ class GPT2Tokenizer(object):
return word
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a path."""
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
json.dump(self.encoder, vocab_file)
@@ -202,6 +205,7 @@ class GPT2Tokenizer(object):
index = token_index
writer.write(bpe_tokens + u'\n')
index += 1
return vocab_file, merge_file
def encode(self, text):
bpe_tokens = []