diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index ac1764de8b..a35fae9ae9 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -804,13 +804,8 @@ class BartForConditionalGeneration(PretrainedBartModel): def __init__(self, config: BartConfig): super().__init__(config) - # if base_model is None: base_model = BartModel(config) self.model = base_model - self.lm_head = _make_linear_from_emb(self.model.shared) - - def tie_weights(self): - pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same. @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) def forward( @@ -875,7 +870,7 @@ class BartForConditionalGeneration(PretrainedBartModel): decoder_cached_states=decoder_cached_states, generation_mode=generation_mode, ) - lm_logits = self.lm_head(outputs[0]) + lm_logits = F.linear(outputs[0], self.model.shared.weight) outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here if lm_labels is not None: loss_fct = nn.CrossEntropyLoss() @@ -932,7 +927,7 @@ class BartForConditionalGeneration(PretrainedBartModel): return self.model.encoder def get_output_embeddings(self): - return self.lm_head + return _make_linear_from_emb(self.model.shared) # make it on the fly @add_start_docstrings( diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index d064f0f780..4a807286f3 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -113,7 +113,8 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_torchscript = False test_head_masking = False - test_resize_embeddings = False # This requires inputs_dict['input_ids'] + test_resize_embeddings = True # This requires inputs_dict['input_ids'] + test_missing_keys = False # because BartForConditionalGeneration and BartModel now have identical state_dict def setUp(self): self.model_tester = ModelTester(self) @@ -371,6 +372,22 @@ class BartHeadTests(unittest.TestCase): ) self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all()) + def test_resize_tokens_embeddings_more(self): + config, input_ids, _ = self._get_config_and_data() + + def _get_embs(m): + return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone()) + + model = BartForConditionalGeneration(config).eval().to(torch_device) + input, output = _get_embs(model) + self.assertTrue(torch.eq(input, output).all()) + new_vocab_size = 45 + model.resize_token_embeddings(new_vocab_size) + input_new, output_new = _get_embs(model) + self.assertEqual(input_new.shape, (new_vocab_size, config.d_model)) + self.assertEqual(output_new.shape, (new_vocab_size, config.d_model)) + self.assertTrue(torch.eq(input_new, output_new).all()) + def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a0d0fe402c..b284ee6ec2 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -58,6 +58,7 @@ class ModelTesterMixin: test_pruning = True test_resize_embeddings = True test_head_masking = True + test_missing_keys = True is_encoder_decoder = False def test_save_load(self): @@ -527,6 +528,8 @@ class ModelTesterMixin: self.assertTrue(x is None or isinstance(x, torch.nn.Linear)) def test_correct_missing_keys(self): + if not self.test_missing_keys: + return config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: