From 6c2ee16c0418a09c13cd59bf285f56feb001d3b5 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 11 Jul 2019 22:09:16 -0400 Subject: [PATCH] Test suite testing the tie_weights function as well as the resize_token_embeddings function. Patched an issue relating to the tied weights I had introduced with the TorchScript addition. Byte order mark management in TSV glue reading. --- examples/utils_glue.py | 2 +- pytorch_transformers/modeling_bert.py | 4 +- pytorch_transformers/modeling_gpt2.py | 4 +- pytorch_transformers/modeling_openai.py | 4 +- .../tests/modeling_common_test.py | 74 +++++++++++++++++++ 5 files changed, 81 insertions(+), 7 deletions(-) diff --git a/examples/utils_glue.py b/examples/utils_glue.py index 5ad36abf10..bba9a901a8 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -78,7 +78,7 @@ class DataProcessor(object): @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" - with open(input_file, "r", encoding="utf-8") as f: + with open(input_file, "r", encoding="utf-8-sig") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index d88c57bb79..23b2e76ec7 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -762,7 +762,7 @@ class BertForPreTraining(BertPreTrainedModel): if self.config.torchscript: self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone()) else: - self.cls.predictions.decoder.weight = input_embeddings # Tied weights + self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None): @@ -868,7 +868,7 @@ class BertForMaskedLM(BertPreTrainedModel): if self.config.torchscript: self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone()) else: - self.cls.predictions.decoder.weight = input_embeddings # Tied weights + self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None): """ diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 06f933147f..5823bad322 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -566,7 +566,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): if self.config.torchscript: self.lm_head.weight = nn.Parameter(input_embeddings.clone()) else: - self.lm_head.weight = input_embeddings # Tied weights + self.lm_head = self.transformer.wte # Tied weights def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None): """ @@ -662,7 +662,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): if self.config.torchscript: self.lm_head.weight = nn.Parameter(input_embeddings.clone()) else: - self.lm_head.weight = input_embeddings # Tied weights + self.lm_head = self.transformer.wte # Tied weights def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None, head_mask=None): diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index ebf1035d21..47a07e77b3 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -587,7 +587,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): if self.config.torchscript: self.lm_head.weight = nn.Parameter(input_embeddings.clone()) else: - self.lm_head.weight = input_embeddings # Tied weights + self.lm_head = self.transformer.tokens_embed # Tied weights def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None): """ @@ -700,7 +700,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): if self.config.torchscript: self.lm_head.weight = nn.Parameter(input_embeddings.clone()) else: - self.lm_head.weight = input_embeddings # Tied weights + self.lm_head = self.transformer.tokens_embed # Tied weights def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, head_mask=None): diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index 98849216fa..9f14c181bb 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -29,6 +29,7 @@ import torch from pytorch_transformers import PretrainedConfig, PreTrainedModel from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers.modeling_gpt2 import GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP def _config_zero_init(config): @@ -470,6 +471,79 @@ class ModelUtilsTest(unittest.TestCase): self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config, config) + def test_resize_tokens_embeddings(self): + logging.basicConfig(level=logging.INFO) + + + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = BertConfig.from_pretrained(model_name) + model = BertModel.from_pretrained(model_name) + + model_vocab_size = config.vocab_size + # Retrieve the embeddings and clone theme + cloned_embeddings = model.embeddings.word_embeddings.weight.clone() + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model.resize_token_embeddings(model_vocab_size + 10) + self.assertEqual(model.config.vocab_size, model_vocab_size + 10) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], cloned_embeddings.shape[0] + 10) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model.resize_token_embeddings(model_vocab_size) + self.assertEqual(model.config.vocab_size, model_vocab_size) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], cloned_embeddings.shape[0]) + + # Check that adding and removing tokens has not modified the first part of the embedding matrix. + models_equal = True + for p1, p2 in zip(cloned_embeddings, model.embeddings.word_embeddings.weight): + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + def test_tie_model_weights(self): + logging.basicConfig(level=logging.INFO) + + def check_same_values(layer_1, layer_2): + equal = True + for p1, p2 in zip(layer_1.weight, layer_2.weight): + if p1.data.ne(p2.data).sum() > 0: + equal = False + return equal + + for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = GPT2Config.from_pretrained(model_name) + model = GPT2LMHeadModel.from_pretrained(model_name) + + # Get the embeddings and decoding layer + embeddings = model.transformer.wte + decoding = model.lm_head + + # Check that the embedding layer and decoding layer are the same in size and in value + self.assertTrue(embeddings.weight.shape, decoding.weight.shape) + self.assertTrue(check_same_values(embeddings, decoding)) + + # Check that after modification, they remain the same. + embeddings.weight.data.div_(2) + # Check that the embedding layer and decoding layer are the same in size and in value + self.assertTrue(embeddings.weight.shape, decoding.weight.shape) + self.assertTrue(check_same_values(embeddings, decoding)) + + # Check that after modification, they remain the same. + decoding.weight.data.div_(4) + # Check that the embedding layer and decoding layer are the same in size and in value + self.assertTrue(embeddings.weight.shape, decoding.weight.shape) + self.assertTrue(check_same_values(embeddings, decoding)) + + # Check that after resize they remain tied. + model.resize_token_embeddings(config.vocab_size + 10) + decoding.weight.data.mul_(20) + # Check that the embedding layer and decoding layer are the same in size and in value + self.assertTrue(embeddings.weight.shape, decoding.weight.shape) + self.assertTrue(check_same_values(embeddings, decoding)) + if __name__ == "__main__": unittest.main()