Test suite testing the tie_weights function as well as the resize_token_embeddings function.
Patched an issue relating to the tied weights I had introduced with the TorchScript addition. Byte order mark management in TSV glue reading.
This commit is contained in:
@@ -78,7 +78,7 @@ class DataProcessor(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _read_tsv(cls, input_file, quotechar=None):
|
def _read_tsv(cls, input_file, quotechar=None):
|
||||||
"""Reads a tab separated value file."""
|
"""Reads a tab separated value file."""
|
||||||
with open(input_file, "r", encoding="utf-8") as f:
|
with open(input_file, "r", encoding="utf-8-sig") as f:
|
||||||
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
|
|||||||
@@ -762,7 +762,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||||
next_sentence_label=None, head_mask=None):
|
next_sentence_label=None, head_mask=None):
|
||||||
@@ -868,7 +868,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -566,7 +566,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head.weight = input_embeddings # Tied weights
|
self.lm_head = self.transformer.wte # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
@@ -662,7 +662,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head.weight = input_embeddings # Tied weights
|
self.lm_head = self.transformer.wte # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||||
position_ids=None, past=None, head_mask=None):
|
position_ids=None, past=None, head_mask=None):
|
||||||
|
|||||||
@@ -587,7 +587,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head.weight = input_embeddings # Tied weights
|
self.lm_head = self.transformer.tokens_embed # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
@@ -700,7 +700,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
if self.config.torchscript:
|
if self.config.torchscript:
|
||||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||||
else:
|
else:
|
||||||
self.lm_head.weight = input_embeddings # Tied weights
|
self.lm_head = self.transformer.tokens_embed # Tied weights
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||||
position_ids=None, head_mask=None):
|
position_ids=None, head_mask=None):
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import torch
|
|||||||
|
|
||||||
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
||||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
from pytorch_transformers.modeling_gpt2 import GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
def _config_zero_init(config):
|
def _config_zero_init(config):
|
||||||
@@ -470,6 +471,79 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
self.assertEqual(model.config.output_hidden_states, True)
|
self.assertEqual(model.config.output_hidden_states, True)
|
||||||
self.assertEqual(model.config, config)
|
self.assertEqual(model.config, config)
|
||||||
|
|
||||||
|
def test_resize_tokens_embeddings(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
config = BertConfig.from_pretrained(model_name)
|
||||||
|
model = BertModel.from_pretrained(model_name)
|
||||||
|
|
||||||
|
model_vocab_size = config.vocab_size
|
||||||
|
# Retrieve the embeddings and clone theme
|
||||||
|
cloned_embeddings = model.embeddings.word_embeddings.weight.clone()
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||||
|
model.resize_token_embeddings(model_vocab_size + 10)
|
||||||
|
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||||
|
model.resize_token_embeddings(model_vocab_size)
|
||||||
|
self.assertEqual(model.config.vocab_size, model_vocab_size)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], cloned_embeddings.shape[0])
|
||||||
|
|
||||||
|
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(cloned_embeddings, model.embeddings.word_embeddings.weight):
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
def test_tie_model_weights(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
def check_same_values(layer_1, layer_2):
|
||||||
|
equal = True
|
||||||
|
for p1, p2 in zip(layer_1.weight, layer_2.weight):
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
equal = False
|
||||||
|
return equal
|
||||||
|
|
||||||
|
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
config = GPT2Config.from_pretrained(model_name)
|
||||||
|
model = GPT2LMHeadModel.from_pretrained(model_name)
|
||||||
|
|
||||||
|
# Get the embeddings and decoding layer
|
||||||
|
embeddings = model.transformer.wte
|
||||||
|
decoding = model.lm_head
|
||||||
|
|
||||||
|
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||||
|
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||||
|
self.assertTrue(check_same_values(embeddings, decoding))
|
||||||
|
|
||||||
|
# Check that after modification, they remain the same.
|
||||||
|
embeddings.weight.data.div_(2)
|
||||||
|
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||||
|
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||||
|
self.assertTrue(check_same_values(embeddings, decoding))
|
||||||
|
|
||||||
|
# Check that after modification, they remain the same.
|
||||||
|
decoding.weight.data.div_(4)
|
||||||
|
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||||
|
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||||
|
self.assertTrue(check_same_values(embeddings, decoding))
|
||||||
|
|
||||||
|
# Check that after resize they remain tied.
|
||||||
|
model.resize_token_embeddings(config.vocab_size + 10)
|
||||||
|
decoding.weight.data.mul_(20)
|
||||||
|
# Check that the embedding layer and decoding layer are the same in size and in value
|
||||||
|
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
|
||||||
|
self.assertTrue(check_same_values(embeddings, decoding))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user