fix tests
This commit is contained in:
12
README.md
12
README.md
@@ -60,7 +60,7 @@ This package comprises the following classes that can be imported in Python and
|
|||||||
- `BertTokenizer` - perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
|
- `BertTokenizer` - perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
|
||||||
|
|
||||||
- One optimizer:
|
- One optimizer:
|
||||||
- `BERTAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
|
- `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
|
||||||
|
|
||||||
- A configuration class:
|
- A configuration class:
|
||||||
- `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilisities to read and write from JSON configuration files.
|
- `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilisities to read and write from JSON configuration files.
|
||||||
@@ -155,7 +155,7 @@ Here is a detailed documentation of the classes in the package and how to use th
|
|||||||
| [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance |
|
| [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance |
|
||||||
| [PyTorch models](#PyTorch-models) | API of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
|
| [PyTorch models](#PyTorch-models) | API of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
|
||||||
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class|
|
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class|
|
||||||
| [Optimizer: `BERTAdam`](#Optimizer-BERTAdam) | API of the `BERTAdam` class |
|
| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |
|
||||||
|
|
||||||
### Loading Google AI's pre-trained weigths and PyTorch dump
|
### Loading Google AI's pre-trained weigths and PyTorch dump
|
||||||
|
|
||||||
@@ -294,12 +294,12 @@ and three methods:
|
|||||||
|
|
||||||
Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
|
Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
|
||||||
|
|
||||||
### Optimizer: `BERTAdam`
|
### Optimizer: `BertAdam`
|
||||||
|
|
||||||
`BERTAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with PyTorch Adam optimizer are the following:
|
`BertAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with PyTorch Adam optimizer are the following:
|
||||||
|
|
||||||
- BERTAdam implements weight decay fix,
|
- BertAdam implements weight decay fix,
|
||||||
- BERTAdam doesn't compensate for bias as in the regular Adam optimizer.
|
- BertAdam doesn't compensate for bias as in the regular Adam optimizer.
|
||||||
|
|
||||||
The optimizer accepts the following arguments:
|
The optimizer accepts the following arguments:
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer
|
from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer
|
||||||
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
|
from pytorch_pretrained_bert.modeling import BertModel
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Sequentia
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer
|
from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer
|
||||||
from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification
|
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
||||||
from pytorch_pretrained_bert.optimization import BERTAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -512,7 +512,7 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
|
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
|
||||||
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
|
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
|
||||||
]
|
]
|
||||||
optimizer = BERTAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_steps)
|
t_total=num_train_steps)
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Sequentia
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer
|
from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer
|
||||||
from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering
|
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
||||||
from pytorch_pretrained_bert.optimization import BERTAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -847,7 +847,7 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
|
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
|
||||||
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
|
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
|
||||||
]
|
]
|
||||||
optimizer = BERTAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_steps)
|
t_total=num_train_steps)
|
||||||
|
|||||||
@@ -2,4 +2,4 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
|||||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||||
BertForMaskedLM, BertForNextSentencePrediction,
|
BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
BertForSequenceClassification, BertForQuestionAnswering)
|
BertForSequenceClassification, BertForQuestionAnswering)
|
||||||
from .optimization import BERTAdam
|
from .optimization import BertAdam
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ SCHEDULES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class BERTAdam(Optimizer):
|
class BertAdam(Optimizer):
|
||||||
"""Implements BERT version of Adam algorithm with weight decay fix.
|
"""Implements BERT version of Adam algorithm with weight decay fix.
|
||||||
Params:
|
Params:
|
||||||
lr: learning rate
|
lr: learning rate
|
||||||
@@ -73,7 +73,7 @@ class BERTAdam(Optimizer):
|
|||||||
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
||||||
b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
|
b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(BERTAdam, self).__init__(params, defaults)
|
super(BertAdam, self).__init__(params, defaults)
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
lr = []
|
lr = []
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import random
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import modeling
|
from pytorch_pretrained_bert import BertConfig, BertModel
|
||||||
|
|
||||||
|
|
||||||
class BertModelTest(unittest.TestCase):
|
class BertModelTest(unittest.TestCase):
|
||||||
@@ -77,8 +77,8 @@ class BertModelTest(unittest.TestCase):
|
|||||||
if self.use_token_type_ids:
|
if self.use_token_type_ids:
|
||||||
token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
config = modeling.BertConfig(
|
config = BertConfig(
|
||||||
vocab_size=self.vocab_size,
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_hidden_layers=self.num_hidden_layers,
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
num_attention_heads=self.num_attention_heads,
|
num_attention_heads=self.num_attention_heads,
|
||||||
@@ -90,7 +90,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
type_vocab_size=self.type_vocab_size,
|
type_vocab_size=self.type_vocab_size,
|
||||||
initializer_range=self.initializer_range)
|
initializer_range=self.initializer_range)
|
||||||
|
|
||||||
model = modeling.BertModel(config=config)
|
model = BertModel(config=config)
|
||||||
|
|
||||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
|
|
||||||
@@ -112,7 +112,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
self.run_tester(BertModelTest.BertModelTester(self))
|
self.run_tester(BertModelTest.BertModelTester(self))
|
||||||
|
|
||||||
def test_config_to_json_string(self):
|
def test_config_to_json_string(self):
|
||||||
config = modeling.BertConfig(vocab_size=99, hidden_size=37)
|
config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
|
||||||
obj = json.loads(config.to_json_string())
|
obj = json.loads(config.to_json_string())
|
||||||
self.assertEqual(obj["vocab_size"], 99)
|
self.assertEqual(obj["vocab_size"], 99)
|
||||||
self.assertEqual(obj["hidden_size"], 37)
|
self.assertEqual(obj["hidden_size"], 37)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import optimization
|
from pytorch_pretrained_bert import BertAdam
|
||||||
|
|
||||||
class OptimizationTest(unittest.TestCase):
|
class OptimizationTest(unittest.TestCase):
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
target = torch.tensor([0.4, 0.2, -0.5])
|
target = torch.tensor([0.4, 0.2, -0.5])
|
||||||
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
|
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
|
||||||
# No warmup, constant schedule, no gradient clipping
|
# No warmup, constant schedule, no gradient clipping
|
||||||
optimizer = optimization.BERTAdam(params=[w], lr=2e-1,
|
optimizer = BertAdam(params=[w], lr=2e-1,
|
||||||
weight_decay_rate=0.0,
|
weight_decay_rate=0.0,
|
||||||
max_grad_norm=-1)
|
max_grad_norm=-1)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import tokenization
|
from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
|
||||||
|
_is_whitespace, _is_control, _is_punctuation)
|
||||||
|
|
||||||
|
|
||||||
class TokenizationTest(unittest.TestCase):
|
class TokenizationTest(unittest.TestCase):
|
||||||
@@ -34,7 +35,7 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
|
|
||||||
vocab_file = vocab_writer.name
|
vocab_file = vocab_writer.name
|
||||||
|
|
||||||
tokenizer = tokenization.BertTokenizer(vocab_file)
|
tokenizer = BertTokenizer(vocab_file)
|
||||||
os.remove(vocab_file)
|
os.remove(vocab_file)
|
||||||
|
|
||||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||||
@@ -44,14 +45,14 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
def test_chinese(self):
|
def test_chinese(self):
|
||||||
tokenizer = tokenization.BasicTokenizer()
|
tokenizer = BasicTokenizer()
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
|
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
|
||||||
[u"ah", u"\u535A", u"\u63A8", u"zz"])
|
[u"ah", u"\u535A", u"\u63A8", u"zz"])
|
||||||
|
|
||||||
def test_basic_tokenizer_lower(self):
|
def test_basic_tokenizer_lower(self):
|
||||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
|
tokenizer = BasicTokenizer(do_lower_case=True)
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||||
@@ -59,7 +60,7 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
||||||
|
|
||||||
def test_basic_tokenizer_no_lower(self):
|
def test_basic_tokenizer_no_lower(self):
|
||||||
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
|
tokenizer = BasicTokenizer(do_lower_case=False)
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||||
@@ -74,7 +75,7 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
vocab = {}
|
vocab = {}
|
||||||
for (i, token) in enumerate(vocab_tokens):
|
for (i, token) in enumerate(vocab_tokens):
|
||||||
vocab[token] = i
|
vocab[token] = i
|
||||||
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
|
tokenizer = WordpieceTokenizer(vocab=vocab)
|
||||||
|
|
||||||
self.assertListEqual(tokenizer.tokenize(""), [])
|
self.assertListEqual(tokenizer.tokenize(""), [])
|
||||||
|
|
||||||
@@ -85,46 +86,32 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||||
|
|
||||||
def test_convert_tokens_to_ids(self):
|
|
||||||
vocab_tokens = [
|
|
||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
|
||||||
"##ing"
|
|
||||||
]
|
|
||||||
|
|
||||||
vocab = {}
|
|
||||||
for (i, token) in enumerate(vocab_tokens):
|
|
||||||
vocab[token] = i
|
|
||||||
|
|
||||||
self.assertListEqual(
|
|
||||||
tokenization.convert_tokens_to_ids(
|
|
||||||
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
|
|
||||||
|
|
||||||
def test_is_whitespace(self):
|
def test_is_whitespace(self):
|
||||||
self.assertTrue(tokenization._is_whitespace(u" "))
|
self.assertTrue(_is_whitespace(u" "))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\t"))
|
self.assertTrue(_is_whitespace(u"\t"))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\r"))
|
self.assertTrue(_is_whitespace(u"\r"))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\n"))
|
self.assertTrue(_is_whitespace(u"\n"))
|
||||||
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
|
self.assertTrue(_is_whitespace(u"\u00A0"))
|
||||||
|
|
||||||
self.assertFalse(tokenization._is_whitespace(u"A"))
|
self.assertFalse(_is_whitespace(u"A"))
|
||||||
self.assertFalse(tokenization._is_whitespace(u"-"))
|
self.assertFalse(_is_whitespace(u"-"))
|
||||||
|
|
||||||
def test_is_control(self):
|
def test_is_control(self):
|
||||||
self.assertTrue(tokenization._is_control(u"\u0005"))
|
self.assertTrue(_is_control(u"\u0005"))
|
||||||
|
|
||||||
self.assertFalse(tokenization._is_control(u"A"))
|
self.assertFalse(_is_control(u"A"))
|
||||||
self.assertFalse(tokenization._is_control(u" "))
|
self.assertFalse(_is_control(u" "))
|
||||||
self.assertFalse(tokenization._is_control(u"\t"))
|
self.assertFalse(_is_control(u"\t"))
|
||||||
self.assertFalse(tokenization._is_control(u"\r"))
|
self.assertFalse(_is_control(u"\r"))
|
||||||
|
|
||||||
def test_is_punctuation(self):
|
def test_is_punctuation(self):
|
||||||
self.assertTrue(tokenization._is_punctuation(u"-"))
|
self.assertTrue(_is_punctuation(u"-"))
|
||||||
self.assertTrue(tokenization._is_punctuation(u"$"))
|
self.assertTrue(_is_punctuation(u"$"))
|
||||||
self.assertTrue(tokenization._is_punctuation(u"`"))
|
self.assertTrue(_is_punctuation(u"`"))
|
||||||
self.assertTrue(tokenization._is_punctuation(u"."))
|
self.assertTrue(_is_punctuation(u"."))
|
||||||
|
|
||||||
self.assertFalse(tokenization._is_punctuation(u"A"))
|
self.assertFalse(_is_punctuation(u"A"))
|
||||||
self.assertFalse(tokenization._is_punctuation(u" "))
|
self.assertFalse(_is_punctuation(u" "))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user