@@ -1184,6 +1184,11 @@ class MBartModel(MBartPreTrainedModel):
|
|||||||
def get_decoder(self):
    """Return the decoder held on this model (``self.decoder``)."""
    return self.decoder
|
||||||
|
|
||||||
|
def _tie_weights(self):
|
||||||
|
if self.config.tie_word_embeddings:
|
||||||
|
self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
|
||||||
|
self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
|||||||
@@ -327,6 +327,43 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
model.generate(input_ids, attention_mask=attention_mask)
|
model.generate(input_ids, attention_mask=attention_mask)
|
||||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||||
|
|
||||||
|
def test_ensure_weights_are_shared(self):
    """Check how many distinct embedding storages the model holds, tied vs. untied.

    For each setting of ``config.tie_word_embeddings`` a fresh
    ``MBartForConditionalGeneration`` is built and the ``data_ptr()`` values of
    four weights (output embeddings, input embeddings, decoder ``embed_tokens``,
    encoder ``embed_tokens``) are collected into a set. The set must contain
    1 pointer when tying is enabled and 2 when it is disabled.
    """
    config, _ = self.model_tester.prepare_config_and_inputs()

    # (tie_word_embeddings, expected number of distinct weight pointers)
    # Tied: all four weights are expected to resolve to one storage.
    # Untied: two distinct storages are expected -- presumably the output
    # embeddings no longer alias the shared input/encoder/decoder embedding
    # (TODO confirm which of the four diverges).
    expectations = ((True, 1), (False, 2))

    for tie_word_embeddings, expected_distinct in expectations:
        config.tie_word_embeddings = tie_word_embeddings
        model = MBartForConditionalGeneration(config)

        # Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
        weight_pointers = {
            model.get_output_embeddings().weight.data_ptr(),
            model.get_input_embeddings().weight.data_ptr(),
            model.base_model.decoder.embed_tokens.weight.data_ptr(),
            model.base_model.encoder.embed_tokens.weight.data_ptr(),
        }
        self.assertEqual(len(weight_pointers), expected_distinct)
|
||||||
|
|
||||||
|
|
||||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
Reference in New Issue
Block a user