From 18a8a15f78a10ac6bf272bc762232b3f16df30e2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 16 Apr 2019 17:00:55 +0200 Subject: [PATCH] improving GPT2 tokenization and adding tests --- README.md | 7 +- pytorch_pretrained_bert/tokenization_gpt2.py | 96 +++++++++++++++---- .../tokenization_openai.py | 15 ++- tests/tokenization_gpt2_test.py | 68 +++++++++++++ tests/tokenization_openai_test.py | 17 ++-- 5 files changed, 169 insertions(+), 34 deletions(-) create mode 100644 tests/tokenization_gpt2_test.py diff --git a/README.md b/README.md index caf415508f..fde35d23ea 100644 --- a/README.md +++ b/README.md @@ -929,10 +929,11 @@ This class has four arguments: and five methods: -- `tokenize(text)`: convert a `str` in a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization. +- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing BPE tokenization. - `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary. - `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary. - `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments) +- `encode(text)`: convert a `str` in a list of `int` tokens by performing BPE encoding. - `decode(ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)`: decode a list of `int` indices in a string and do some post-processing if needed: (i) remove special tokens from the output and (ii) clean up tokenization spaces. - `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`. @@ -958,6 +959,10 @@ This class has three arguments: and two methods: +- `tokenize(text)`: convert a `str` in a list of `str` tokens by performing byte-level BPE. +- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens in a list of `int` indices in the vocabulary. +- `convert_ids_to_tokens(tokens)`: convert a list of `int` indices in a list of `str` tokens in the vocabulary. +- `set_special_tokens(self, special_tokens)`: update the list of special tokens (see above arguments) - `encode(text)`: convert a `str` in a list of `int` tokens by performing byte-level BPE. - `decode(tokens)`: convert back a list of `int` tokens in a `str`. - `save_vocabulary(directory_path)`: save the vocabulary, merge and special tokens files to `directory_path`. Return the path to the three files: `vocab_file_path`, `merge_file_path`, `special_tokens_file_path`. The vocabulary can be reloaded with `OpenAIGPTTokenizer.from_pretrained('directory_path')`. diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py index ab80876ee5..491db616e4 100644 --- a/pytorch_pretrained_bert/tokenization_gpt2.py +++ b/pytorch_pretrained_bert/tokenization_gpt2.py @@ -16,6 +16,7 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) +import sys import json import logging import os @@ -138,7 +139,7 @@ class GPT2Tokenizer(object): tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) return tokenizer - def __init__(self, vocab_file, merges_file, errors='replace', max_len=None): + def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): self.max_len = max_len if max_len is not None else int(1e12) self.encoder = json.load(open(vocab_file)) self.decoder = {v:k for k,v in self.encoder.items()} @@ -153,8 +154,25 @@ class GPT2Tokenizer(object): # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + def __len__(self): - return len(self.encoder) + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} + logger.info("Special tokens {}".format(self.special_tokens)) def bpe(self, token): if token in self.cache: @@ -197,6 +215,54 @@ class GPT2Tokenizer(object): self.cache[token] = word return word + def tokenize(self, text): + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. """ + ids = [] + if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this OpenAI GPT model ({} > {}). Running this" + " sequence through the model will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + return text + def save_vocabulary(self, vocab_path): """Save the tokenizer vocabulary and merge files to a directory.""" if not os.path.isdir(vocab_path): @@ -220,26 +286,14 @@ class GPT2Tokenizer(object): writer.write(' '.join(bpe_tokens) + u'\n') index += 1 + index = len(self.encoder) with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]): + for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + index = token_index writer.write(token + u'\n') + index += 1 return vocab_file, merge_file, special_tokens_file - - def encode(self, text): - bpe_tokens = [] - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - if len(bpe_tokens) > self.max_len: - logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this OpenAI GPT-2 model ({} > {}). Running this" - " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len) - ) - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) - return text diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index 7a10271175..1088b5222b 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -150,6 +150,8 @@ class OpenAIGPTTokenizer(object): merges = [tuple(merge.split()) for merge in merges] self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} + self.special_tokens = {} + self.special_tokens_decoder = {} self.set_special_tokens(special_tokens) def __len__(self): @@ -261,7 +263,10 @@ class OpenAIGPTTokenizer(object): tokens.append(self.decoder[i]) return tokens - def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False): + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): """Converts a sequence of ids in a string.""" tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) out_string = ''.join(tokens).replace('', ' ').strip() @@ -296,8 +301,14 @@ class OpenAIGPTTokenizer(object): writer.write(' '.join(bpe_tokens) + u'\n') index += 1 + index = len(self.encoder) with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]): + for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + index = token_index writer.write(token + u'\n') + index += 1 return vocab_file, merge_file, special_tokens_file diff --git a/tests/tokenization_gpt2_test.py b/tests/tokenization_gpt2_test.py new file mode 100644 index 0000000000..29633bc17c --- /dev/null +++ b/tests/tokenization_gpt2_test.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import unittest +import json + +from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer + + +class GPT2TokenizationTest(unittest.TestCase): + + def test_full_tokenizer(self): + """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ + vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", + "lo", "low", "er", + "low", "lowest", "newer", "wider"] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "l o", "lo w", "e r", ""] + with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp: + json.dump(vocab_tokens, fp) + vocab_file = fp.name + with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp: + fp.write("\n".join(merges)) + merges_file = fp.name + + tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["", ""]) + os.remove(vocab_file) + os.remove(merges_file) + + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [13, 12, 16] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/") + os.remove(vocab_file) + os.remove(merges_file) + os.remove(special_tokens_file) + + self.assertListEqual( + [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks, + tokenizer.special_tokens, tokenizer.special_tokens_decoder], + [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks, + tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tokenization_openai_test.py b/tests/tokenization_openai_test.py index 1f695cfb12..fb42cdd8cb 100644 --- a/tests/tokenization_openai_test.py +++ b/tests/tokenization_openai_test.py @@ -38,7 +38,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): fp.write("\n".join(merges)) merges_file = fp.name - tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=[""]) + tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["", ""]) os.remove(vocab_file) os.remove(merges_file) @@ -53,19 +53,16 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/") - tokenizer.from_pretrained("/tmp/") + tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/") os.remove(vocab_file) os.remove(merges_file) + os.remove(special_tokens_file) - text = "lower" - bpe_tokens = ["low", "er"] - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, bpe_tokens) - - input_tokens = tokens + [""] - input_bpe_tokens = [14, 15, 20] self.assertListEqual( - tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks, + tokenizer.special_tokens, tokenizer.special_tokens_decoder], + [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks, + tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder]) if __name__ == '__main__':