[Bart/Memory] don't create lm_head (#3323)

* Delete lm_head and skip weight tying
* Fixed s3
This commit is contained in:
Sam Shleifer
2020-03-26 18:40:39 -04:00
committed by GitHub
parent 5ad2ea06af
commit 39371ee454
3 changed files with 23 additions and 8 deletions

View File

@@ -113,7 +113,8 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_head_masking = False
test_resize_embeddings = False # This requires inputs_dict['input_ids']
test_resize_embeddings = True # This requires inputs_dict['input_ids']
test_missing_keys = False # because BartForConditionalGeneration and BartModel now have identical state_dict
def setUp(self):
self.model_tester = ModelTester(self)
@@ -371,6 +372,22 @@ class BartHeadTests(unittest.TestCase):
)
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
def test_resize_tokens_embeddings_more(self):
    """Check input/output embeddings agree before and after a vocab resize.

    Builds a ``BartForConditionalGeneration`` from the test config, asserts
    that the input and output embedding weights are elementwise equal, then
    calls ``resize_token_embeddings`` and asserts that both weights have
    shape ``(new_vocab_size, config.d_model)`` and are still elementwise
    equal.
    """
    # Only the config is needed here; the other returned values are unused.
    config, _, _ = self._get_config_and_data()

    def _get_embs(m):
        """Return cloned (input, output) embedding weight tensors of ``m``."""
        # Clone so the snapshots are independent of later in-place changes.
        return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())

    model = BartForConditionalGeneration(config).eval().to(torch_device)

    # Names avoid shadowing the ``input`` builtin.
    input_embs, output_embs = _get_embs(model)
    self.assertTrue(torch.eq(input_embs, output_embs).all())

    new_vocab_size = 45
    model.resize_token_embeddings(new_vocab_size)

    input_embs_new, output_embs_new = _get_embs(model)
    self.assertEqual(input_embs_new.shape, (new_vocab_size, config.d_model))
    self.assertEqual(output_embs_new.shape, (new_vocab_size, config.d_model))
    self.assertTrue(torch.eq(input_embs_new, output_embs_new).all())
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""