From 757750d6f6a5883f40815a8f80ff77eb8b0e82df Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 17 Nov 2018 11:58:14 +0100 Subject: [PATCH] fix tests --- README.md | 12 ++--- examples/extract_features.py | 2 +- examples/run_classifier.py | 6 +-- examples/run_squad.py | 6 +-- pytorch_pretrained_bert/__init__.py | 2 +- pytorch_pretrained_bert/optimization.py | 4 +- tests/modeling_test.py | 10 ++-- tests/optimization_test.py | 4 +- tests/tokenization_test.py | 63 ++++++++++--------------- 9 files changed, 48 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 06d5a18fba..f42ae74137 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ This package comprises the following classes that can be imported in Python and - `BertTokenizer` - perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization. - One optimizer: - - `BERTAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate. + - `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate. - A configuration class: - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilisities to read and write from JSON configuration files. @@ -155,7 +155,7 @@ Here is a detailed documentation of the classes in the package and how to use th | [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance | | [PyTorch models](#PyTorch-models) | API of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` | | [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class| -| [Optimizer: `BERTAdam`](#Optimizer-BERTAdam) | API of the `BERTAdam` class | +| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class | ### Loading Google AI's pre-trained weigths and PyTorch dump @@ -294,12 +294,12 @@ and three methods: Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing. -### Optimizer: `BERTAdam` +### Optimizer: `BertAdam` -`BERTAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with PyTorch Adam optimizer are the following: +`BertAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with PyTorch Adam optimizer are the following: -- BERTAdam implements weight decay fix, -- BERTAdam doesn't compensate for bias as in the regular Adam optimizer. +- BertAdam implements weight decay fix, +- BertAdam doesn't compensate for bias as in the regular Adam optimizer. The optimizer accepts the following arguments: diff --git a/examples/extract_features.py b/examples/extract_features.py index fce90dffa2..84e9a8d6ab 100644 --- a/examples/extract_features.py +++ b/examples/extract_features.py @@ -29,7 +29,7 @@ from torch.utils.data import TensorDataset, DataLoader, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer -from pytorch_pretrained_bert.modeling import BertConfig, BertModel +from pytorch_pretrained_bert.modeling import BertModel logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 7ad2492b90..0dc5347e2e 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -31,8 +31,8 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Sequentia from torch.utils.data.distributed import DistributedSampler from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer -from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification -from pytorch_pretrained_bert.optimization import BERTAdam +from pytorch_pretrained_bert.modeling import BertForSequenceClassification +from pytorch_pretrained_bert.optimization import BertAdam logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -512,7 +512,7 @@ def main(): {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0} ] - optimizer = BERTAdam(optimizer_grouped_parameters, + optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_steps) diff --git a/examples/run_squad.py b/examples/run_squad.py index c9acbbac7e..45dbee29c2 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -33,8 +33,8 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Sequentia from torch.utils.data.distributed import DistributedSampler from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer -from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering -from pytorch_pretrained_bert.optimization import BERTAdam +from pytorch_pretrained_bert.modeling import BertForQuestionAnswering +from pytorch_pretrained_bert.optimization import BertAdam logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -847,7 +847,7 @@ def main(): {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01}, {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0} ] - optimizer = BERTAdam(optimizer_grouped_parameters, + optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_steps) diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index 066bd7830e..e12e48c2b9 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -2,4 +2,4 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForQuestionAnswering) -from .optimization import BERTAdam +from .optimization import BertAdam diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index a5ac283b07..4266a8f83b 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -41,7 +41,7 @@ SCHEDULES = { } -class BERTAdam(Optimizer): +class BertAdam(Optimizer): """Implements BERT version of Adam algorithm with weight decay fix. Params: lr: learning rate @@ -73,7 +73,7 @@ class BERTAdam(Optimizer): defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, max_grad_norm=max_grad_norm) - super(BERTAdam, self).__init__(params, defaults) + super(BertAdam, self).__init__(params, defaults) def get_lr(self): lr = [] diff --git a/tests/modeling_test.py b/tests/modeling_test.py index d3d937a06e..48d56826f8 100644 --- a/tests/modeling_test.py +++ b/tests/modeling_test.py @@ -22,7 +22,7 @@ import random import torch -import modeling +from pytorch_pretrained_bert import BertConfig, BertModel class BertModelTest(unittest.TestCase): @@ -77,8 +77,8 @@ class BertModelTest(unittest.TestCase): if self.use_token_type_ids: token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - config = modeling.BertConfig( - vocab_size=self.vocab_size, + config = BertConfig( + vocab_size_or_config_json_file=self.vocab_size, hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, @@ -90,7 +90,7 @@ class BertModelTest(unittest.TestCase): type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range) - model = modeling.BertModel(config=config) + model = BertModel(config=config) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) @@ -112,7 +112,7 @@ class BertModelTest(unittest.TestCase): self.run_tester(BertModelTest.BertModelTester(self)) def test_config_to_json_string(self): - config = modeling.BertConfig(vocab_size=99, hidden_size=37) + config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37) obj = json.loads(config.to_json_string()) self.assertEqual(obj["vocab_size"], 99) self.assertEqual(obj["hidden_size"], 37) diff --git a/tests/optimization_test.py b/tests/optimization_test.py index 1675c8b496..1c010750ae 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -20,7 +20,7 @@ import unittest import torch -import optimization +from pytorch_pretrained_bert import BertAdam class OptimizationTest(unittest.TestCase): @@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase): target = torch.tensor([0.4, 0.2, -0.5]) criterion = torch.nn.MSELoss(reduction='elementwise_mean') # No warmup, constant schedule, no gradient clipping - optimizer = optimization.BERTAdam(params=[w], lr=2e-1, + optimizer = BertAdam(params=[w], lr=2e-1, weight_decay_rate=0.0, max_grad_norm=-1) for _ in range(100): diff --git a/tests/tokenization_test.py b/tests/tokenization_test.py index fda1cdb243..f541a620e8 100644 --- a/tests/tokenization_test.py +++ b/tests/tokenization_test.py @@ -19,7 +19,8 @@ from __future__ import print_function import os import unittest -import tokenization +from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer, + _is_whitespace, _is_control, _is_punctuation) class TokenizationTest(unittest.TestCase): @@ -34,7 +35,7 @@ class TokenizationTest(unittest.TestCase): vocab_file = vocab_writer.name - tokenizer = tokenization.BertTokenizer(vocab_file) + tokenizer = BertTokenizer(vocab_file) os.remove(vocab_file) tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") @@ -44,14 +45,14 @@ class TokenizationTest(unittest.TestCase): tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) def test_chinese(self): - tokenizer = tokenization.BasicTokenizer() + tokenizer = BasicTokenizer() self.assertListEqual( tokenizer.tokenize(u"ah\u535A\u63A8zz"), [u"ah", u"\u535A", u"\u63A8", u"zz"]) def test_basic_tokenizer_lower(self): - tokenizer = tokenization.BasicTokenizer(do_lower_case=True) + tokenizer = BasicTokenizer(do_lower_case=True) self.assertListEqual( tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), @@ -59,7 +60,7 @@ class TokenizationTest(unittest.TestCase): self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) def test_basic_tokenizer_no_lower(self): - tokenizer = tokenization.BasicTokenizer(do_lower_case=False) + tokenizer = BasicTokenizer(do_lower_case=False) self.assertListEqual( tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), @@ -74,7 +75,7 @@ class TokenizationTest(unittest.TestCase): vocab = {} for (i, token) in enumerate(vocab_tokens): vocab[token] = i - tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) + tokenizer = WordpieceTokenizer(vocab=vocab) self.assertListEqual(tokenizer.tokenize(""), []) @@ -85,46 +86,32 @@ class TokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) - def test_convert_tokens_to_ids(self): - vocab_tokens = [ - "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", - "##ing" - ] - - vocab = {} - for (i, token) in enumerate(vocab_tokens): - vocab[token] = i - - self.assertListEqual( - tokenization.convert_tokens_to_ids( - vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) - def test_is_whitespace(self): - self.assertTrue(tokenization._is_whitespace(u" ")) - self.assertTrue(tokenization._is_whitespace(u"\t")) - self.assertTrue(tokenization._is_whitespace(u"\r")) - self.assertTrue(tokenization._is_whitespace(u"\n")) - self.assertTrue(tokenization._is_whitespace(u"\u00A0")) + self.assertTrue(_is_whitespace(u" ")) + self.assertTrue(_is_whitespace(u"\t")) + self.assertTrue(_is_whitespace(u"\r")) + self.assertTrue(_is_whitespace(u"\n")) + self.assertTrue(_is_whitespace(u"\u00A0")) - self.assertFalse(tokenization._is_whitespace(u"A")) - self.assertFalse(tokenization._is_whitespace(u"-")) + self.assertFalse(_is_whitespace(u"A")) + self.assertFalse(_is_whitespace(u"-")) def test_is_control(self): - self.assertTrue(tokenization._is_control(u"\u0005")) + self.assertTrue(_is_control(u"\u0005")) - self.assertFalse(tokenization._is_control(u"A")) - self.assertFalse(tokenization._is_control(u" ")) - self.assertFalse(tokenization._is_control(u"\t")) - self.assertFalse(tokenization._is_control(u"\r")) + self.assertFalse(_is_control(u"A")) + self.assertFalse(_is_control(u" ")) + self.assertFalse(_is_control(u"\t")) + self.assertFalse(_is_control(u"\r")) def test_is_punctuation(self): - self.assertTrue(tokenization._is_punctuation(u"-")) - self.assertTrue(tokenization._is_punctuation(u"$")) - self.assertTrue(tokenization._is_punctuation(u"`")) - self.assertTrue(tokenization._is_punctuation(u".")) + self.assertTrue(_is_punctuation(u"-")) + self.assertTrue(_is_punctuation(u"$")) + self.assertTrue(_is_punctuation(u"`")) + self.assertTrue(_is_punctuation(u".")) - self.assertFalse(tokenization._is_punctuation(u"A")) - self.assertFalse(tokenization._is_punctuation(u" ")) + self.assertFalse(_is_punctuation(u"A")) + self.assertFalse(_is_punctuation(u" ")) if __name__ == '__main__':