diff --git a/README.md b/README.md index c31bbd24b7..c130675dbd 100644 --- a/README.md +++ b/README.md @@ -345,8 +345,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/') ### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules -The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer. -The new optimizer `AdamW` matches PyTorch `Adam` optimizer API. +The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer which has a few differences: + +- it only implements weight decay correction, +- schedules are now external (see below), +- gradient clipping is now also external (see below). + +The new optimizer `AdamW` matches PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping. The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore. 
@@ -355,6 +360,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch ```python # Parameters: lr = 1e-3 +max_grad_norm = 1.0 num_total_steps = 1000 num_warmup_steps = 100 warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1 @@ -374,6 +380,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot for batch in train_data: loss = model(batch) loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue) scheduler.step() optimizer.step() ``` diff --git a/docs/source/migration.md b/docs/source/migration.md index ba09253472..9cfcaade13 100644 --- a/docs/source/migration.md +++ b/docs/source/migration.md @@ -68,8 +68,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/') ### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules -The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer. -The new optimizer `AdamW` matches PyTorch `Adam` optimizer API. +The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer which has a few differences: + +- it only implements weight decay correction, +- schedules are now external (see below), +- gradient clipping is now also external (see below). + +The new optimizer `AdamW` matches PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping. The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore. 
@@ -78,6 +83,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch ```python # Parameters: lr = 1e-3 +max_grad_norm = 1.0 num_total_steps = 1000 num_warmup_steps = 100 warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1 @@ -97,6 +103,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot for batch in train_data: loss = model(batch) loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Gradient clipping is not in AdamW anymore (so you can use amp without issue) scheduler.step() optimizer.step() ``` diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 61854f61ea..7117d7ffa6 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -122,7 +122,7 @@ Here is the recommended way of saving the model, configuration and vocabulary to .. code-block:: python - from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME + from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME output_dir = "./models/" diff --git a/docs/source/torchscript.rst b/docs/source/torchscript.rst index 1b84559567..3c38177353 100644 --- a/docs/source/torchscript.rst +++ b/docs/source/torchscript.rst @@ -74,7 +74,7 @@ according to a ``BertConfig`` class and then saved to disk under the filename `` .. 
code-block:: python - from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig + from pytorch_transformers import BertModel, BertTokenizer, BertConfig import torch enc = BertTokenizer.from_pretrained("bert-base-uncased") @@ -105,6 +105,9 @@ according to a ``BertConfig`` class and then saved to disk under the filename `` # The model needs to be in evaluation mode model.eval() + # If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag + model = BertModel.from_pretrained("bert-base-uncased", torchscript=True) + # Creating the trace traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) torch.jit.save(traced_model, "traced_bert.pt") diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index c9b0aeebb7..72d666448e 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -39,4 +39,4 @@ from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) -from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path) +from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path) diff --git a/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py b/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py index b8858ee3dc..d866365fd0 100644 --- a/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py +++ b/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py @@ -20,7 +20,7 @@ import argparse import torch import numpy as np import tensorflow as tf -from pytorch_pretrained_bert.modeling import BertModel +from pytorch_transformers.modeling import BertModel def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 
fd655cec0e..75c075720c 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -38,10 +38,13 @@ except ImportError: try: from pathlib import Path PYTORCH_PRETRAINED_BERT_CACHE = Path( - os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) + os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))) except (AttributeError, ImportError): - PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', - default_cache_path) + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE', + os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + default_cache_path)) + +PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -70,7 +73,7 @@ def filename_to_url(filename, cache_dir=None): Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. """ if cache_dir is None: - cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + cache_dir = PYTORCH_TRANSFORMERS_CACHE if sys.version_info[0] == 3 and isinstance(cache_dir, Path): cache_dir = str(cache_dir) @@ -98,7 +101,7 @@ def cached_path(url_or_filename, cache_dir=None): make sure the file exists and then return the path. """ if cache_dir is None: - cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + cache_dir = PYTORCH_TRANSFORMERS_CACHE if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): url_or_filename = str(url_or_filename) if sys.version_info[0] == 3 and isinstance(cache_dir, Path): @@ -187,7 +190,7 @@ def get_from_cache(url, cache_dir=None): If it's not there, download it. Then return the path to the cached file. 
""" if cache_dir is None: - cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + cache_dir = PYTORCH_TRANSFORMERS_CACHE if sys.version_info[0] == 3 and isinstance(cache_dir, Path): cache_dir = str(cache_dir) if sys.version_info[0] == 2 and not isinstance(cache_dir, str): diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index 0b9cfb1b32..5eb39b729d 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -24,30 +24,37 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer, _is_control, _is_punctuation, _is_whitespace, VOCAB_FILES_NAMES) -from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory +from .tokenization_tests_commons import CommonTestCases -class TokenizationTest(unittest.TestCase): +class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): + + tokenizer_class = BertTokenizer + + def setUp(self): + super(BertTokenizationTest, self).setUp() - def test_full_tokenizer(self): vocab_tokens = [ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest", ] - with TemporaryDirectory() as tmpdirname: - vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) - with open(vocab_file, "w", encoding='utf-8') as vocab_writer: - vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - input_text = u"UNwant\u00E9d,running" - output_text = u"unwanted, running" + def get_tokenizer(self): + return BertTokenizer.from_pretrained(self.tmpdirname) - create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname) + def get_input_output_texts(self): + input_text = 
u"UNwant\u00E9d,running" + output_text = u"unwanted, running" + return input_text, output_text - tokenizer = BertTokenizer(vocab_file) + def test_full_tokenizer(self): + tokenizer = BertTokenizer(self.vocab_file) - tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") - self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) - self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) + tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") + self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) def test_chinese(self): tokenizer = BasicTokenizer() diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index 8dae72ec99..da7028c27d 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -20,42 +20,49 @@ import json from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES -from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory +from .tokenization_tests_commons import CommonTestCases -class GPT2TokenizationTest(unittest.TestCase): +class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): - def test_full_tokenizer(self): - """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ + tokenizer_class = GPT2Tokenizer + + def setUp(self): + super(GPT2TokenizationTest, self).setUp() + + # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "low", "er", "low", "lowest", "newer", "wider", ""] vocab_tokens = dict(zip(vocab, range(len(vocab)))) merges = ["#version: 0.2", "l o", "lo w", "e r", ""] - special_tokens_map = {"unk_token": ""} + self.special_tokens_map = {"unk_token": ""} - with TemporaryDirectory() as tmpdirname: - vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) - merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file']) - with open(vocab_file, "w") as fp: - fp.write(json.dumps(vocab_tokens)) - with open(merges_file, "w") as fp: - fp.write("\n".join(merges)) + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) + with open(self.vocab_file, "w") as fp: + fp.write(json.dumps(vocab_tokens)) + with open(self.merges_file, "w") as fp: + fp.write("\n".join(merges)) - input_text = u"lower newer" - output_text = u"lowernewer" + def get_tokenizer(self): + return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) - create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map) + def get_input_output_texts(self): + input_text = u"lower newer" + output_text = u"lowernewer" + return input_text, output_text - tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map) - text = "lower" - bpe_tokens = ["low", "er"] - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, bpe_tokens) + def test_full_tokenizer(self): + tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) - input_tokens = tokens + [tokenizer.unk_token] - input_bpe_tokens = [13, 12, 17] - self.assertListEqual( - 
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + input_tokens = tokens + [tokenizer.unk_token] + input_bpe_tokens = [13, 12, 17] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) if __name__ == '__main__': diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py index 9b4841a605..bb354f3fb7 100644 --- a/pytorch_transformers/tests/tokenization_openai_test.py +++ b/pytorch_transformers/tests/tokenization_openai_test.py @@ -20,13 +20,17 @@ import json from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES -from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory +from .tokenization_tests_commons import CommonTestCases -class OpenAIGPTTokenizationTest(unittest.TestCase): +class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): - def test_full_tokenizer(self): - """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ + tokenizer_class = OpenAIGPTTokenizer + + def setUp(self): + super(OpenAIGPTTokenizationTest, self).setUp() + + # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "w", "r", "t", "lo", "low", "er", @@ -34,30 +38,34 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): vocab_tokens = dict(zip(vocab, range(len(vocab)))) merges = ["#version: 0.2", "l o", "lo w", "e r", ""] - with TemporaryDirectory() as tmpdirname: - vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) - merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file']) - with open(vocab_file, "w") as fp: - fp.write(json.dumps(vocab_tokens)) - with open(merges_file, "w") as fp: - fp.write("\n".join(merges)) + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) + with open(self.vocab_file, "w") as fp: + fp.write(json.dumps(vocab_tokens)) + with open(self.merges_file, "w") as fp: + fp.write("\n".join(merges)) - input_text = u"lower newer" - output_text = u"lower newer" + def get_tokenizer(self): + return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname) - create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname) + def get_input_output_texts(self): + input_text = u"lower newer" + output_text = u"lower newer" + return input_text, output_text - tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file) - text = "lower" - bpe_tokens = ["low", "er"] - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, bpe_tokens) + def test_full_tokenizer(self): + tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file) - input_tokens = tokens + [""] - input_bpe_tokens = [14, 15, 20] - self.assertListEqual( - tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + 
self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) if __name__ == '__main__': diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index c37770b229..ebcf6f48d8 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -19,6 +19,7 @@ import sys from io import open import tempfile import shutil +import unittest if sys.version_info[0] == 2: import cPickle as pickle @@ -36,113 +37,124 @@ else: unicode = str -def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs): - tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) +class CommonTestCases: - before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") + class CommonTokenizerTester(unittest.TestCase): - with TemporaryDirectory() as tmpdirname: - tokenizer.save_pretrained(tmpdirname) - tokenizer = tokenizer.from_pretrained(tmpdirname) + tokenizer_class = None - after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") - tester.assertListEqual(before_tokens, after_tokens) + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() -def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs): - tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) - tester.assertIsNotNone(tokenizer) + def tearDown(self): + shutil.rmtree(self.tmpdirname) - text = u"Munich and Berlin are nice cities" - subwords = tokenizer.tokenize(text) + def get_tokenizer(self): + raise NotImplementedError - with TemporaryDirectory() as tmpdirname: + def get_input_output_texts(self): + raise NotImplementedError - filename = os.path.join(tmpdirname, u"tokenizer.bin") - pickle.dump(tokenizer, open(filename, "wb")) + def test_save_and_load_tokenizer(self): + tokenizer = self.get_tokenizer() - tokenizer_new = pickle.load(open(filename, "rb")) + before_tokens = 
tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") - subwords_loaded = tokenizer_new.tokenize(text) + with TemporaryDirectory() as tmpdirname: + tokenizer.save_pretrained(tmpdirname) + tokenizer = tokenizer.from_pretrained(tmpdirname) - tester.assertListEqual(subwords, subwords_loaded) + after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") + self.assertListEqual(before_tokens, after_tokens) + + def test_pickle_tokenizer(self): + tokenizer = self.get_tokenizer() + self.assertIsNotNone(tokenizer) + + text = u"Munich and Berlin are nice cities" + subwords = tokenizer.tokenize(text) + + with TemporaryDirectory() as tmpdirname: + + filename = os.path.join(tmpdirname, u"tokenizer.bin") + pickle.dump(tokenizer, open(filename, "wb")) + + tokenizer_new = pickle.load(open(filename, "rb")) + + subwords_loaded = tokenizer_new.tokenize(text) + + self.assertListEqual(subwords, subwords_loaded) -def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs): - tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) + def test_add_tokens_tokenizer(self): + tokenizer = self.get_tokenizer() - vocab_size = tokenizer.vocab_size - all_size = len(tokenizer) + vocab_size = tokenizer.vocab_size + all_size = len(tokenizer) - tester.assertNotEqual(vocab_size, 0) - tester.assertEqual(vocab_size, all_size) + self.assertNotEqual(vocab_size, 0) + self.assertEqual(vocab_size, all_size) - new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"] - added_toks = tokenizer.add_tokens(new_toks) - vocab_size_2 = tokenizer.vocab_size - all_size_2 = len(tokenizer) + new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"] + added_toks = tokenizer.add_tokens(new_toks) + vocab_size_2 = tokenizer.vocab_size + all_size_2 = len(tokenizer) - tester.assertNotEqual(vocab_size_2, 0) - tester.assertEqual(vocab_size, vocab_size_2) - tester.assertEqual(added_toks, len(new_toks)) - tester.assertEqual(all_size_2, all_size + len(new_toks)) + self.assertNotEqual(vocab_size_2, 
0) + self.assertEqual(vocab_size, vocab_size_2) + self.assertEqual(added_toks, len(new_toks)) + self.assertEqual(all_size_2, all_size + len(new_toks)) - tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") - tester.assertGreaterEqual(len(tokens), 4) - tester.assertGreater(tokens[0], tokenizer.vocab_size - 1) - tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") + self.assertGreaterEqual(len(tokens), 4) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) - new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", - 'pad_token': "<<<<<|||>|>>>>|>"} - added_toks_2 = tokenizer.add_special_tokens(new_toks_2) - vocab_size_3 = tokenizer.vocab_size - all_size_3 = len(tokenizer) + new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", + 'pad_token': "<<<<<|||>|>>>>|>"} + added_toks_2 = tokenizer.add_special_tokens(new_toks_2) + vocab_size_3 = tokenizer.vocab_size + all_size_3 = len(tokenizer) - tester.assertNotEqual(vocab_size_3, 0) - tester.assertEqual(vocab_size, vocab_size_3) - tester.assertEqual(added_toks_2, len(new_toks_2)) - tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) + self.assertNotEqual(vocab_size_3, 0) + self.assertEqual(vocab_size, vocab_size_3) + self.assertEqual(added_toks_2, len(new_toks_2)) + self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) - tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") + tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") - tester.assertGreaterEqual(len(tokens), 6) - tester.assertGreater(tokens[0], tokenizer.vocab_size - 1) - tester.assertGreater(tokens[0], tokens[1]) - tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1) - tester.assertGreater(tokens[-2], tokens[-3]) - tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token)) - 
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) + self.assertGreaterEqual(len(tokens), 6) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[0], tokens[1]) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokens[-3]) + self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token)) + self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) -def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): - tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) + def test_required_methods_tokenizer(self): + tokenizer = self.get_tokenizer() + input_text, output_text = self.get_input_output_texts() - tokens = tokenizer.tokenize(input_text) - ids = tokenizer.convert_tokens_to_ids(tokens) - ids_2 = tokenizer.encode(input_text) - tester.assertListEqual(ids, ids_2) + tokens = tokenizer.tokenize(input_text) + ids = tokenizer.convert_tokens_to_ids(tokens) + ids_2 = tokenizer.encode(input_text) + self.assertListEqual(ids, ids_2) - tokens_2 = tokenizer.convert_ids_to_tokens(ids) - text_2 = tokenizer.decode(ids) + tokens_2 = tokenizer.convert_ids_to_tokens(ids) + text_2 = tokenizer.decode(ids) - tester.assertEqual(text_2, output_text) + self.assertEqual(text_2, output_text) - tester.assertNotEqual(len(tokens_2), 0) - tester.assertIsInstance(text_2, (str, unicode)) + self.assertNotEqual(len(tokens_2), 0) + self.assertIsInstance(text_2, (str, unicode)) -def create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): - weights_list = list(tokenizer_class.max_model_input_sizes.keys()) - weights_lists_2 = [] - for file_id, map_list in tokenizer_class.pretrained_vocab_files_map.items(): - weights_lists_2.append(list(map_list.keys())) + def test_pretrained_model_lists(self): + weights_list = 
list(self.tokenizer_class.max_model_input_sizes.keys()) + weights_lists_2 = [] + for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items(): + weights_lists_2.append(list(map_list.keys())) - for weights_list_2 in weights_lists_2: - tester.assertListEqual(weights_list, weights_list_2) - - -def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): - create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs) - create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs) - create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs) - create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs) - create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs) + for weights_list_2 in weights_lists_2: + self.assertListEqual(weights_list, weights_list_2) diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index aecfeaef5f..fbd06cf47e 100644 --- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py +++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py @@ -20,32 +20,39 @@ from io import open from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES -from.tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory +from.tokenization_tests_commons import CommonTestCases -class TransfoXLTokenizationTest(unittest.TestCase): +class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): + + tokenizer_class = TransfoXLTokenizer + + def setUp(self): + super(TransfoXLTokenizationTest, self).setUp() - def test_full_tokenizer(self): vocab_tokens = [ "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ",", "low", "l", ] - with TemporaryDirectory() as tmpdirname: - 
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) - with open(vocab_file, "w", encoding='utf-8') as vocab_writer: - vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - input_text = u" UNwanted , running" - output_text = u" unwanted, running" + def get_tokenizer(self): + return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True) - create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True) + def get_input_output_texts(self): + input_text = u" UNwanted , running" + output_text = u" unwanted, running" + return input_text, output_text - tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True) + def test_full_tokenizer(self): + tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True) - tokens = tokenizer.tokenize(u" UNwanted , running") - self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) + tokens = tokenizer.tokenize(u" UNwanted , running") + self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) - self.assertListEqual( - tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) def test_full_tokenizer_lower(self): tokenizer = TransfoXLTokenizer(lower_case=True) diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py index 97e8fa983f..a20e92044f 100644 --- a/pytorch_transformers/tests/tokenization_xlm_test.py +++ b/pytorch_transformers/tests/tokenization_xlm_test.py @@ -20,12 +20,16 @@ import json from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES -from .tokenization_tests_commons import create_and_check_tokenizer_commons, 
TemporaryDirectory +from .tokenization_tests_commons import CommonTestCases -class XLMTokenizationTest(unittest.TestCase): +class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): - def test_full_tokenizer(self): - """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ + tokenizer_class = XLMTokenizer + + def setUp(self): + super(XLMTokenizationTest, self).setUp() + + # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "w", "r", "t", "lo", "low", "er", @@ -33,30 +37,34 @@ class XLMTokenizationTest(unittest.TestCase): vocab_tokens = dict(zip(vocab, range(len(vocab)))) merges = ["l o 123", "lo w 1456", "e r 1789", ""] - with TemporaryDirectory() as tmpdirname: - vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) - merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file']) - with open(vocab_file, "w") as fp: - fp.write(json.dumps(vocab_tokens)) - with open(merges_file, "w") as fp: - fp.write("\n".join(merges)) + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) + with open(self.vocab_file, "w") as fp: + fp.write(json.dumps(vocab_tokens)) + with open(self.merges_file, "w") as fp: + fp.write("\n".join(merges)) - input_text = u"lower newer" - output_text = u"lower newer" + def get_tokenizer(self): + return XLMTokenizer.from_pretrained(self.tmpdirname) - create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname) + def get_input_output_texts(self): + input_text = u"lower newer" + output_text = u"lower newer" + return input_text, output_text - tokenizer = XLMTokenizer(vocab_file, merges_file) + def test_full_tokenizer(self): + """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """ + tokenizer = XLMTokenizer(self.vocab_file, self.merges_file) - text = "lower" - bpe_tokens = ["low", "er"] - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, bpe_tokens) + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) - input_tokens = tokens + [""] - input_bpe_tokens = [14, 15, 20] - self.assertListEqual( - tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) if __name__ == '__main__': diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py index 27c6b984ee..08e9e9cb2d 100644 --- a/pytorch_transformers/tests/tokenization_xlnet_test.py +++ b/pytorch_transformers/tests/tokenization_xlnet_test.py @@ -19,48 +19,58 @@ import unittest from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) -from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory +from .tokenization_tests_commons import CommonTestCases SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fixtures/test_sentencepiece.model') -class XLNetTokenizationTest(unittest.TestCase): +class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): + + tokenizer_class = XLNetTokenizer + + def setUp(self): + super(XLNetTokenizationTest, self).setUp() + + # We have a SentencePiece fixture for testing + tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) + tokenizer.save_pretrained(self.tmpdirname) + + def get_tokenizer(self): + return XLNetTokenizer.from_pretrained(self.tmpdirname) + + def get_input_output_texts(self): + input_text = u"This is a test" + output_text = u"This is a test" + return input_text, output_text + def 
test_full_tokenizer(self): tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) - with TemporaryDirectory() as tmpdirname: - tokenizer.save_pretrained(tmpdirname) + tokens = tokenizer.tokenize(u'This is a test') + self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) - input_text = u"This is a test" - output_text = u"This is a test" + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) - create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname) + tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") + self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', + u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', + u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', + SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual( + ids, [8, 21, 84, 55, 24, 19, 7, 0, + 602, 347, 347, 347, 3, 12, 66, + 46, 72, 80, 6, 0, 4]) - tokens = tokenizer.tokenize(u'This is a test') - self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) - - self.assertListEqual( - tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) - - tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") - self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', - u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', - u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', - SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) - ids = tokenizer.convert_tokens_to_ids(tokens) - self.assertListEqual( - ids, [8, 21, 84, 55, 24, 19, 7, 0, - 602, 347, 347, 347, 3, 12, 66, - 46, 72, 80, 6, 0, 4]) - - back_tokens = tokenizer.convert_ids_to_tokens(ids) - 
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', - u'or', u'n', SPIECE_UNDERLINE + u'in', - SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', - SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', - SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', - u'', u'.']) + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', + u'or', u'n', SPIECE_UNDERLINE + u'in', + SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', + SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', + SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', + u'', u'.']) def test_tokenizer_lower(self): tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index d7aeff7c39..9bf18a97d7 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -86,7 +86,7 @@ def whitespace_tokenize(text): class BertTokenizer(PreTrainedTokenizer): r""" Constructs a BertTokenizer. 
- :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece + :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece Args: vocab_file: Path to a one-wordpiece-per-line vocabulary file diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 1852d74021..a81a5b9235 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -125,42 +125,34 @@ class PreTrainedTokenizer(object): @bos_token.setter def bos_token(self, value): - self.add_tokens([value]) self._bos_token = value @eos_token.setter def eos_token(self, value): - self.add_tokens([value]) self._eos_token = value @unk_token.setter def unk_token(self, value): - self.add_tokens([value]) self._unk_token = value @sep_token.setter def sep_token(self, value): - self.add_tokens([value]) self._sep_token = value @pad_token.setter def pad_token(self, value): - self.add_tokens([value]) self._pad_token = value @cls_token.setter def cls_token(self, value): - self.add_tokens([value]) self._cls_token = value @mask_token.setter def mask_token(self, value): - self.add_tokens([value]) self._mask_token = value @additional_special_tokens.setter def additional_special_tokens(self, value): - self.add_tokens(value) self._additional_special_tokens = value def __init__(self, max_len=None, **kwargs): @@ -179,6 +171,10 @@ class PreTrainedTokenizer(object): for key, value in kwargs.items(): if key in self.SPECIAL_TOKENS_ATTRIBUTES: + if key == 'additional_special_tokens': + assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value) + else: + assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) setattr(self, key, value) @@ -415,15 +411,39 @@ class PreTrainedTokenizer(object): Tokens are only added if they are not already in the vocabulary (tested by 
checking if the tokenizer assign the index of the ``unk_token`` to them). + Returns: + Number of tokens added to the vocabulary. + + Examples:: + + # Let's see how to add a new classification token to GPT-2 + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + model = GPT2Model.from_pretrained('gpt2') + + special_tokens_dict = {'cls_token': '<CLS>'} + + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + print('We have added', num_added_toks, 'tokens') + model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer. + + assert tokenizer.cls_token == '<CLS>' """ if not special_tokens_dict: return 0 + added_tokens = 0 for key, value in special_tokens_dict.items(): assert key in self.SPECIAL_TOKENS_ATTRIBUTES + if key == 'additional_special_tokens': + assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value) + added_tokens += self.add_tokens(value) + else: + assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) + added_tokens += self.add_tokens([value]) logger.info("Assigning %s to the %s key of the tokenizer", value, key) setattr(self, key, value) + return added_tokens def tokenize(self, text, **kwargs): """ Converts a string in a sequence of tokens (string), using the tokenizer.