Fix layer reference loss + previous attempted fix

This commit is contained in:
LysandreJik
2019-07-11 22:29:55 -04:00
parent 6c2ee16c04
commit 3fbceed8d2
4 changed files with 8 additions and 8 deletions

View File

@@ -541,8 +541,8 @@ class ModelUtilsTest(unittest.TestCase):
model.resize_token_embeddings(config.vocab_size + 10)
decoding.weight.data.mul_(20)
# Check that the embedding layer and decoding layer are the same in size and in value
self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
self.assertTrue(check_same_values(embeddings, decoding))
self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
if __name__ == "__main__":