Persist embedding type of BART and mBART models after resize (#32242)

* fix: persist embedding type of MBartConditonalGeneration after resize

* fix: persist embedding type of BartConditonalGeneration after resize
This commit is contained in:
Abdi
2024-08-05 21:15:36 +08:00
committed by GitHub
parent f5f1e52f6c
commit baf7e5c927
4 changed files with 28 additions and 2 deletions

View File

@@ -375,6 +375,18 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_load_save_without_tied_weights(self):
pass
def test_resize_embeddings_persists_embeddings_type(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.scale_embedding = True
model = MBartForConditionalGeneration(config)
old_type = type(model.model.decoder.embed_tokens)
model.resize_token_embeddings(new_num_tokens=config.vocab_size)
new_type = type(model.model.decoder.embed_tokens)
self.assertIs(old_type, new_type)
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""