[Bart/Memory] don't create lm_head (#3323)
* delete lm_head, skips weight tying * Fixed s3
This commit is contained in:
@@ -804,13 +804,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
|
|
||||||
def __init__(self, config: BartConfig):
|
def __init__(self, config: BartConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
# if base_model is None:
|
|
||||||
base_model = BartModel(config)
|
base_model = BartModel(config)
|
||||||
self.model = base_model
|
self.model = base_model
|
||||||
self.lm_head = _make_linear_from_emb(self.model.shared)
|
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -875,7 +870,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
decoder_cached_states=decoder_cached_states,
|
decoder_cached_states=decoder_cached_states,
|
||||||
generation_mode=generation_mode,
|
generation_mode=generation_mode,
|
||||||
)
|
)
|
||||||
lm_logits = self.lm_head(outputs[0])
|
lm_logits = F.linear(outputs[0], self.model.shared.weight)
|
||||||
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
|
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
@@ -932,7 +927,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
return self.model.encoder
|
return self.model.encoder
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return _make_linear_from_emb(self.model.shared) # make it on the fly
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
|
|||||||
@@ -113,7 +113,8 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_resize_embeddings = False # This requires inputs_dict['input_ids']
|
test_resize_embeddings = True # This requires inputs_dict['input_ids']
|
||||||
|
test_missing_keys = False # because BartForConditionalGeneration and BartModel now have identical state_dict
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ModelTester(self)
|
self.model_tester = ModelTester(self)
|
||||||
@@ -371,6 +372,22 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
|
self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
|
||||||
|
|
||||||
|
def test_resize_tokens_embeddings_more(self):
|
||||||
|
config, input_ids, _ = self._get_config_and_data()
|
||||||
|
|
||||||
|
def _get_embs(m):
|
||||||
|
return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
|
||||||
|
|
||||||
|
model = BartForConditionalGeneration(config).eval().to(torch_device)
|
||||||
|
input, output = _get_embs(model)
|
||||||
|
self.assertTrue(torch.eq(input, output).all())
|
||||||
|
new_vocab_size = 45
|
||||||
|
model.resize_token_embeddings(new_vocab_size)
|
||||||
|
input_new, output_new = _get_embs(model)
|
||||||
|
self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
|
||||||
|
self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
|
||||||
|
self.assertTrue(torch.eq(input_new, output_new).all())
|
||||||
|
|
||||||
|
|
||||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class ModelTesterMixin:
|
|||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
|
test_missing_keys = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def test_save_load(self):
|
def test_save_load(self):
|
||||||
@@ -527,6 +528,8 @@ class ModelTesterMixin:
|
|||||||
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
|
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
|
||||||
|
|
||||||
def test_correct_missing_keys(self):
|
def test_correct_missing_keys(self):
|
||||||
|
if not self.test_missing_keys:
|
||||||
|
return
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|||||||
Reference in New Issue
Block a user