improving GPT2 tokenization and adding tests
This commit is contained in:
@@ -38,7 +38,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
|
||||
@@ -53,19 +53,16 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||
tokenizer.from_pretrained("/tmp/")
|
||||
tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
os.remove(special_tokens_file)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
|
||||
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
|
||||
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
|
||||
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user