From e8568a3b17454dd4e0b32b6cd80617aa662cc996 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Apr 2019 12:55:38 +0200 Subject: [PATCH] fixing tests --- pytorch_pretrained_bert/tokenization_gpt2.py | 27 ++++++++++++++++--- .../tokenization_openai.py | 27 ++++++++++++++++--- tests/tokenization_openai_test.py | 2 +- tests/tokenization_transfo_xl_test.py | 9 +++---- 4 files changed, 51 insertions(+), 14 deletions(-) diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py index b49e1310e4..ab80876ee5 100644 --- a/pytorch_pretrained_bert/tokenization_gpt2.py +++ b/pytorch_pretrained_bert/tokenization_gpt2.py @@ -45,6 +45,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { } VOCAB_NAME = 'vocab.json' MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' @lru_cache() def bytes_to_unicode(): @@ -97,6 +98,11 @@ class GPT2Tokenizer(object): else: vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) # redirect to the cache, if necessary try: resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) @@ -125,7 +131,11 @@ class GPT2Tokenizer(object): max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. - tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) return tokenizer def __init__(self, vocab_file, merges_file, errors='replace', max_len=None): @@ -194,7 +204,11 @@ class GPT2Tokenizer(object): return vocab_file = os.path.join(vocab_path, VOCAB_NAME) merge_file = os.path.join(vocab_path, MERGES_NAME) - json.dump(self.encoder, vocab_file) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + index = 0 with open(merge_file, "w", encoding="utf-8") as writer: writer.write(u'#version: 0.2\n') @@ -203,9 +217,14 @@ class GPT2Tokenizer(object): logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." " Please check that the tokenizer is not corrupted!".format(merge_file)) index = token_index - writer.write(bpe_tokens + u'\n') + writer.write(' '.join(bpe_tokens) + u'\n') index += 1 - return vocab_file, merge_file + + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]): + writer.write(token + u'\n') + + return vocab_file, merge_file, special_tokens_file def encode(self, text): bpe_tokens = [] diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index f3ce7ab251..d9713e51eb 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { } VOCAB_NAME = 'vocab.json' MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' def get_pairs(word): """ @@ -89,6 +90,11 @@ class OpenAIGPTTokenizer(object): else: vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) # redirect to the cache, if necessary try: resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) @@ -117,7 +123,11 @@ class OpenAIGPTTokenizer(object): max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. - tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) return tokenizer def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): @@ -269,7 +279,11 @@ class OpenAIGPTTokenizer(object): return vocab_file = os.path.join(vocab_path, VOCAB_NAME) merge_file = os.path.join(vocab_path, MERGES_NAME) - json.dump(self.encoder, vocab_file) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + index = 0 with open(merge_file, "w", encoding="utf-8") as writer: writer.write(u'#version: 0.2\n') @@ -278,6 +292,11 @@ class OpenAIGPTTokenizer(object): logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." " Please check that the tokenizer is not corrupted!".format(merge_file)) index = token_index - writer.write(bpe_tokens + u'\n') + writer.write(' '.join(bpe_tokens) + u'\n') index += 1 - return vocab_file, merge_file + + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]): + writer.write(token + u'\n') + + return vocab_file, merge_file, special_tokens_file diff --git a/tests/tokenization_openai_test.py b/tests/tokenization_openai_test.py index 2b1bdd3a9a..1f695cfb12 100644 --- a/tests/tokenization_openai_test.py +++ b/tests/tokenization_openai_test.py @@ -52,7 +52,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) - vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/") tokenizer.from_pretrained("/tmp/") os.remove(vocab_file) os.remove(merges_file) diff --git a/tests/tokenization_transfo_xl_test.py b/tests/tokenization_transfo_xl_test.py index add2eb4e71..1a805f11e6 100644 --- a/tests/tokenization_transfo_xl_test.py +++ b/tests/tokenization_transfo_xl_test.py @@ -35,7 +35,7 @@ class TransfoXLTokenizationTest(unittest.TestCase): tokenizer.build_vocab() os.remove(vocab_file) - tokens = tokenizer.tokenize(u" UNwant\u00E9d,running") + tokens = tokenizer.tokenize(u" UNwanted , running") self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) self.assertListEqual( @@ -45,7 +45,7 @@ class TransfoXLTokenizationTest(unittest.TestCase): tokenizer.from_pretrained(vocab_file) os.remove(vocab_file) - tokens = tokenizer.tokenize(u" UNwant\u00E9d,running") + tokens = tokenizer.tokenize(u" UNwanted , running") self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) self.assertListEqual( @@ -56,15 +56,14 @@ class TransfoXLTokenizationTest(unittest.TestCase): tokenizer = TransfoXLTokenizer(lower_case=True) self.assertListEqual( - tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), + tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), ["hello", "!", "how", "are", "you", "?"]) - self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) def test_full_tokenizer_no_lower(self): tokenizer = TransfoXLTokenizer(lower_case=False) self.assertListEqual( - tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), + tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"])