[Bart/Memory] don't create lm_head (#3323)
* delete lm_head, skips weight tying * Fixed s3
This commit is contained in:
@@ -804,13 +804,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
# if base_model is None:
|
||||
base_model = BartModel(config)
|
||||
self.model = base_model
|
||||
self.lm_head = _make_linear_from_emb(self.model.shared)
|
||||
|
||||
def tie_weights(self):
|
||||
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
|
||||
|
||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
@@ -875,7 +870,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
generation_mode=generation_mode,
|
||||
)
|
||||
lm_logits = self.lm_head(outputs[0])
|
||||
lm_logits = F.linear(outputs[0], self.model.shared.weight)
|
||||
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||
if lm_labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
@@ -932,7 +927,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return self.model.encoder
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
return _make_linear_from_emb(self.model.shared) # make it on the fly
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
|
||||
@@ -113,7 +113,8 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_head_masking = False
|
||||
test_resize_embeddings = False # This requires inputs_dict['input_ids']
|
||||
test_resize_embeddings = True # This requires inputs_dict['input_ids']
|
||||
test_missing_keys = False # because BartForConditionalGeneration and BartModel now have identical state_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ModelTester(self)
|
||||
@@ -371,6 +372,22 @@ class BartHeadTests(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
|
||||
|
||||
def test_resize_tokens_embeddings_more(self):
|
||||
config, input_ids, _ = self._get_config_and_data()
|
||||
|
||||
def _get_embs(m):
|
||||
return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
|
||||
|
||||
model = BartForConditionalGeneration(config).eval().to(torch_device)
|
||||
input, output = _get_embs(model)
|
||||
self.assertTrue(torch.eq(input, output).all())
|
||||
new_vocab_size = 45
|
||||
model.resize_token_embeddings(new_vocab_size)
|
||||
input_new, output_new = _get_embs(model)
|
||||
self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
|
||||
self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
|
||||
self.assertTrue(torch.eq(input_new, output_new).all())
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
|
||||
@@ -58,6 +58,7 @@ class ModelTesterMixin:
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
test_missing_keys = True
|
||||
is_encoder_decoder = False
|
||||
|
||||
def test_save_load(self):
|
||||
@@ -527,6 +528,8 @@ class ModelTesterMixin:
|
||||
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
|
||||
|
||||
def test_correct_missing_keys(self):
|
||||
if not self.test_missing_keys:
|
||||
return
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user