From 5dd7b677adbd2a228328e42b79583143c16b8dff Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 12:43:08 +0200 Subject: [PATCH] clean up all byte-level bpe tests --- pytorch_transformers/tests/tokenization_gpt2_test.py | 4 ++-- pytorch_transformers/tests/tokenization_roberta_test.py | 6 +++--- pytorch_transformers/tokenization_gpt2.py | 9 +++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index 3e4fb5bc1d..fbeaa2a6df 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -57,12 +57,12 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): def test_full_tokenizer(self): tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) text = "lower newer" - bpe_tokens = ["\u0120low", "er", "\u0120newer"] + bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] tokens = tokenizer.tokenize(text) self.assertListEqual(tokens, bpe_tokens) input_tokens = tokens + [tokenizer.unk_token] - input_bpe_tokens = [14, 15, 19] + input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) diff --git a/pytorch_transformers/tests/tokenization_roberta_test.py b/pytorch_transformers/tests/tokenization_roberta_test.py index e2082e7613..e1e7f45efd 100644 --- a/pytorch_transformers/tests/tokenization_roberta_test.py +++ b/pytorch_transformers/tests/tokenization_roberta_test.py @@ -55,13 +55,13 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): def test_full_tokenizer(self): tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) - text = "lower" - bpe_tokens = ["\u0120low", "er"] + text = "lower newer" + bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] tokens = tokenizer.tokenize(text) self.assertListEqual(tokens, bpe_tokens) input_tokens = tokens + [tokenizer.unk_token] - input_bpe_tokens = [14, 15, 19] + input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index 1fa7cbd06b..8a9ade87e1 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -64,13 +64,14 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @lru_cache() def bytes_to_unicode(): """ - Returns list of utf-8 byte and a corresponding list of unicode strings. + Returns list of utf-8 byte and a mapping to unicode strings. + We specifically avoids mapping to whitespace/control characters the bpe code barfs on. + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. """ _chr = unichr if sys.version_info[0] == 2 else chr bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) @@ -176,9 +177,9 @@ class GPT2Tokenizer(PreTrainedTokenizer): bpe_tokens = [] for token in re.findall(self.pat, text): if sys.version_info[0] == 2: - token = ''.join(self.byte_encoder[ord(b)] for b in token) + token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case) else: - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case) bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) return bpe_tokens