Test suite testing the tie_weights function as well as the resize_token_embeddings function.
Patched an issue relating to the tied weights I had introduced with the TorchScript addition. Byte order mark management in TSV glue reading.
This commit is contained in:
@@ -29,6 +29,7 @@ import torch
|
||||
|
||||
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_gpt2 import GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
@@ -470,6 +471,79 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
model = BertModel.from_pretrained(model_name)
|
||||
|
||||
model_vocab_size = config.vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
cloned_embeddings = model.embeddings.word_embeddings.weight.clone()
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model.resize_token_embeddings(model_vocab_size)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], cloned_embeddings.shape[0])
|
||||
|
||||
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||
models_equal = True
|
||||
for p1, p2 in zip(cloned_embeddings, model.embeddings.word_embeddings.weight):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_tie_model_weights(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def check_same_values(layer_1, layer_2):
|
||||
equal = True
|
||||
for p1, p2 in zip(layer_1.weight, layer_2.weight):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
equal = False
|
||||
return equal
|
||||
|
||||
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
config = GPT2Config.from_pretrained(model_name)
|
||||
model = GPT2LMHeadModel.from_pretrained(model_name)
|
||||
|
||||
# Get the embeddings and decoding layer
|
||||
embeddings = model.transformer.wte
|
||||
decoding = model.lm_head
|
||||
|
||||
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||
self.assertTrue(check_same_values(embeddings, decoding))
|
||||
|
||||
# Check that after modification, they remain the same.
|
||||
embeddings.weight.data.div_(2)
|
||||
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||
self.assertTrue(check_same_values(embeddings, decoding))
|
||||
|
||||
# Check that after modification, they remain the same.
|
||||
decoding.weight.data.div_(4)
|
||||
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||
self.assertTrue(check_same_values(embeddings, decoding))
|
||||
|
||||
# Check that after resize they remain tied.
|
||||
model.resize_token_embeddings(config.vocab_size + 10)
|
||||
decoding.weight.data.mul_(20)
|
||||
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||
self.assertTrue(check_same_values(embeddings, decoding))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user