improving GPT2 tokenization and adding tests
This commit is contained in:
@@ -929,10 +929,11 @@ This class has four arguments:
|
|||||||
|
|
||||||
and five methods:
|
and five methods:
|
||||||
|
|
||||||
- `tokenize(text)`: convert a `str` in a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization.
|
- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing BPE tokenization.
|
||||||
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
|
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
|
||||||
- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
|
- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
|
||||||
- `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
|
- `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
|
||||||
|
- `encode(text)`: convert a `str` in a list of `int` tokens by performing BPE encoding.
|
||||||
- `decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)`: decode a list of `int` indices in a string and do some post-processing if needed: (i) remove special tokens from the output and (ii) clean up tokenization spaces.
|
- `decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)`: decode a list of `int` indices in a string and do some post-processing if needed: (i) remove special tokens from the output and (ii) clean up tokenization spaces.
|
||||||
- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.
|
- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.
|
||||||
|
|
||||||
@@ -958,6 +959,10 @@ This class has three arguments:
|
|||||||
|
|
||||||
and two methods:
|
and two methods:
|
||||||
|
|
||||||
|
- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing byte-level BPE.
|
||||||
|
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary.
|
||||||
|
- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary.
|
||||||
|
- `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments)
|
||||||
- `encode(text)`: convert a `str` in a list of `int` tokens by performing byte-level BPE.
|
- `encode(text)`: convert a `str` in a list of `int` tokens by performing byte-level BPE.
|
||||||
- `decode(tokens)`: convert back a list of `int` tokens in a `str`.
|
- `decode(tokens)`: convert back a list of `int` tokens in a `str`.
|
||||||
- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.
|
- `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`.
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import (absolute_import, division, print_function,
|
||||||
unicode_literals)
|
unicode_literals)
|
||||||
|
|
||||||
|
import sys
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -138,7 +139,7 @@ class GPT2Tokenizer(object):
|
|||||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
|
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
|
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
self.encoder = json.load(open(vocab_file))
|
self.encoder = json.load(open(vocab_file))
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
@@ -153,8 +154,25 @@ class GPT2Tokenizer(object):
|
|||||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||||
|
|
||||||
|
self.special_tokens = {}
|
||||||
|
self.special_tokens_decoder = {}
|
||||||
|
self.set_special_tokens(special_tokens)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.encoder)
|
return len(self.encoder) + len(self.special_tokens)
|
||||||
|
|
||||||
|
def set_special_tokens(self, special_tokens):
|
||||||
|
""" Add a list of additional tokens to the encoder.
|
||||||
|
The additional tokens are indexed starting from the last index of the
|
||||||
|
current vocabulary in the order of the `special_tokens` list.
|
||||||
|
"""
|
||||||
|
if not special_tokens:
|
||||||
|
self.special_tokens = {}
|
||||||
|
self.special_tokens_decoder = {}
|
||||||
|
return
|
||||||
|
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||||
|
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
|
||||||
|
logger.info("Special tokens {}".format(self.special_tokens))
|
||||||
|
|
||||||
def bpe(self, token):
|
def bpe(self, token):
|
||||||
if token in self.cache:
|
if token in self.cache:
|
||||||
@@ -197,6 +215,54 @@ class GPT2Tokenizer(object):
|
|||||||
self.cache[token] = word
|
self.cache[token] = word
|
||||||
return word
|
return word
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
""" Tokenize a string. """
|
||||||
|
bpe_tokens = []
|
||||||
|
for token in re.findall(self.pat, text):
|
||||||
|
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||||
|
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
|
||||||
|
return bpe_tokens
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, tokens):
|
||||||
|
""" Converts a sequence of tokens into ids using the vocab. """
|
||||||
|
ids = []
|
||||||
|
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||||
|
if tokens in self.special_tokens:
|
||||||
|
return self.special_tokens[tokens]
|
||||||
|
else:
|
||||||
|
return self.encoder.get(tokens, 0)
|
||||||
|
for token in tokens:
|
||||||
|
if token in self.special_tokens:
|
||||||
|
ids.append(self.special_tokens[token])
|
||||||
|
else:
|
||||||
|
ids.append(self.encoder.get(token, 0))
|
||||||
|
if len(ids) > self.max_len:
|
||||||
|
logger.warning(
|
||||||
|
"Token indices sequence length is longer than the specified maximum "
|
||||||
|
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||||
|
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||||
|
)
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||||
|
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||||
|
tokens = []
|
||||||
|
for i in ids:
|
||||||
|
if i in self.special_tokens_decoder:
|
||||||
|
if not skip_special_tokens:
|
||||||
|
tokens.append(self.special_tokens_decoder[i])
|
||||||
|
else:
|
||||||
|
tokens.append(self.decoder[i])
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
text = ''.join([self.decoder[token] for token in tokens])
|
||||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||||
|
return text
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, vocab_path):
|
||||||
"""Save the tokenizer vocabulary and merge files to a directory."""
|
"""Save the tokenizer vocabulary and merge files to a directory."""
|
||||||
if not os.path.isdir(vocab_path):
|
if not os.path.isdir(vocab_path):
|
||||||
@@ -220,26 +286,14 @@ class GPT2Tokenizer(object):
|
|||||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
|
index = len(self.encoder)
|
||||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||||
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
|
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||||
|
if index != token_index:
|
||||||
|
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||||
|
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||||
|
index = token_index
|
||||||
writer.write(token + u'\n')
|
writer.write(token + u'\n')
|
||||||
|
index += 1
|
||||||
|
|
||||||
return vocab_file, merge_file, special_tokens_file
|
return vocab_file, merge_file, special_tokens_file
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
bpe_tokens = []
|
|
||||||
for token in re.findall(self.pat, text):
|
|
||||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
|
||||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
|
||||||
if len(bpe_tokens) > self.max_len:
|
|
||||||
logger.warning(
|
|
||||||
"Token indices sequence length is longer than the specified maximum "
|
|
||||||
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
|
|
||||||
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
|
|
||||||
)
|
|
||||||
return bpe_tokens
|
|
||||||
|
|
||||||
def decode(self, tokens):
|
|
||||||
text = ''.join([self.decoder[token] for token in tokens])
|
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
|
||||||
return text
|
|
||||||
|
|||||||
@@ -150,6 +150,8 @@ class OpenAIGPTTokenizer(object):
|
|||||||
merges = [tuple(merge.split()) for merge in merges]
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
self.special_tokens = {}
|
||||||
|
self.special_tokens_decoder = {}
|
||||||
self.set_special_tokens(special_tokens)
|
self.set_special_tokens(special_tokens)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -261,7 +263,10 @@ class OpenAIGPTTokenizer(object):
|
|||||||
tokens.append(self.decoder[i])
|
tokens.append(self.decoder[i])
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
|
def encode(self, text):
|
||||||
|
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||||
|
|
||||||
|
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||||
"""Converts a sequence of ids in a string."""
|
"""Converts a sequence of ids in a string."""
|
||||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||||
@@ -296,8 +301,14 @@ class OpenAIGPTTokenizer(object):
|
|||||||
writer.write(' '.join(bpe_tokens) + u'\n')
|
writer.write(' '.join(bpe_tokens) + u'\n')
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
|
index = len(self.encoder)
|
||||||
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
|
||||||
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
|
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
|
||||||
|
if index != token_index:
|
||||||
|
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
|
||||||
|
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
|
||||||
|
index = token_index
|
||||||
writer.write(token + u'\n')
|
writer.write(token + u'\n')
|
||||||
|
index += 1
|
||||||
|
|
||||||
return vocab_file, merge_file, special_tokens_file
|
return vocab_file, merge_file, special_tokens_file
|
||||||
|
|||||||
68
tests/tokenization_gpt2_test.py
Normal file
68
tests/tokenization_gpt2_test.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import json
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2TokenizationTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_full_tokenizer(self):
|
||||||
|
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||||
|
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||||
|
"lo", "low", "er",
|
||||||
|
"low", "lowest", "newer", "wider"]
|
||||||
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
|
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
||||||
|
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||||
|
json.dump(vocab_tokens, fp)
|
||||||
|
vocab_file = fp.name
|
||||||
|
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||||
|
fp.write("\n".join(merges))
|
||||||
|
merges_file = fp.name
|
||||||
|
|
||||||
|
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||||
|
os.remove(vocab_file)
|
||||||
|
os.remove(merges_file)
|
||||||
|
|
||||||
|
text = "lower"
|
||||||
|
bpe_tokens = ["low", "er"]
|
||||||
|
tokens = tokenizer.tokenize(text)
|
||||||
|
self.assertListEqual(tokens, bpe_tokens)
|
||||||
|
|
||||||
|
input_tokens = tokens + ["<unk>"]
|
||||||
|
input_bpe_tokens = [13, 12, 16]
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||||
|
tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
|
||||||
|
os.remove(vocab_file)
|
||||||
|
os.remove(merges_file)
|
||||||
|
os.remove(special_tokens_file)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
|
||||||
|
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
|
||||||
|
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
|
||||||
|
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
@@ -38,7 +38,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
|||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
merges_file = fp.name
|
merges_file = fp.name
|
||||||
|
|
||||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
|
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||||
os.remove(vocab_file)
|
os.remove(vocab_file)
|
||||||
os.remove(merges_file)
|
os.remove(merges_file)
|
||||||
|
|
||||||
@@ -53,19 +53,16 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
|||||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
|
||||||
tokenizer.from_pretrained("/tmp/")
|
tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
|
||||||
os.remove(vocab_file)
|
os.remove(vocab_file)
|
||||||
os.remove(merges_file)
|
os.remove(merges_file)
|
||||||
|
os.remove(special_tokens_file)
|
||||||
|
|
||||||
text = "lower"
|
|
||||||
bpe_tokens = ["low", "er</w>"]
|
|
||||||
tokens = tokenizer.tokenize(text)
|
|
||||||
self.assertListEqual(tokens, bpe_tokens)
|
|
||||||
|
|
||||||
input_tokens = tokens + ["<unk>"]
|
|
||||||
input_bpe_tokens = [14, 15, 20]
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
|
||||||
|
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
|
||||||
|
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
|
||||||
|
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user