[Bart/Memory] don't create lm_head (#3323)
* delete lm_head, skips weight tying * Fixed s3
This commit is contained in:
@@ -113,7 +113,8 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_head_masking = False
|
||||
test_resize_embeddings = False # This requires inputs_dict['input_ids']
|
||||
test_resize_embeddings = True # This requires inputs_dict['input_ids']
|
||||
test_missing_keys = False # because BartForConditionalGeneration and BartModel now have identical state_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ModelTester(self)
|
||||
@@ -371,6 +372,22 @@ class BartHeadTests(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
|
||||
|
||||
def test_resize_tokens_embeddings_more(self):
|
||||
config, input_ids, _ = self._get_config_and_data()
|
||||
|
||||
def _get_embs(m):
|
||||
return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
|
||||
|
||||
model = BartForConditionalGeneration(config).eval().to(torch_device)
|
||||
input, output = _get_embs(model)
|
||||
self.assertTrue(torch.eq(input, output).all())
|
||||
new_vocab_size = 45
|
||||
model.resize_token_embeddings(new_vocab_size)
|
||||
input_new, output_new = _get_embs(model)
|
||||
self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
|
||||
self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
|
||||
self.assertTrue(torch.eq(input_new, output_new).all())
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
|
||||
Reference in New Issue
Block a user