fix_mbart_tied_weights (#26422)

* fix_mbart_tied_weights

* add test
This commit is contained in:
Marc Sun
2023-09-28 15:08:35 +02:00
committed by GitHub
parent 216dff7549
commit 5e11d72d4d
2 changed files with 42 additions and 0 deletions

View File

@@ -327,6 +327,43 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
def test_ensure_weights_are_shared(self):
    """Check how many distinct embedding tensors the model holds, with and without tying.

    Four weights are inspected: the output embeddings, the input embeddings,
    and the `embed_tokens` of the base model's decoder and encoder. Their
    `data_ptr()` values are collected into a set, so the size of that set is
    the number of distinct underlying tensors.

    * `tie_word_embeddings=True`  -> all four must be one and the same tensor.
    * `tie_word_embeddings=False` -> exactly two distinct tensors are expected
      (the test only checks the count, not which weights are paired).

    Not an issue to not have these correctly tied for torch.load, but it is an
    issue for safetensors.
    """
    config, _ = self.model_tester.prepare_config_and_inputs()

    # (value for config.tie_word_embeddings, expected number of distinct tensors)
    scenarios = ((True, 1), (False, 2))

    for tie_word_embeddings, expected_distinct in scenarios:
        config.tie_word_embeddings = tie_word_embeddings
        model = MBartForConditionalGeneration(config)

        base_model = model.base_model
        embedding_weights = (
            model.get_output_embeddings().weight,
            model.get_input_embeddings().weight,
            base_model.decoder.embed_tokens.weight,
            base_model.encoder.embed_tokens.weight,
        )
        # Identical pointers collapse in the set, leaving one entry per distinct tensor.
        distinct_pointers = {weight.data_ptr() for weight in embedding_weights}

        self.assertEqual(len(distinct_pointers), expected_distinct)
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""