Fix loading with only state dict and low_cpu_mem_usage = True (#35217)

* fix loading with only state dict and config

* style

* add tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Marc Sun
2024-12-18 09:54:32 +01:00
committed by GitHub
parent 0531d7513b
commit 1eee1cedfd
2 changed files with 26 additions and 3 deletions

View File

@@ -1750,6 +1750,26 @@ class ModelUtilsTest(TestCasePlus):
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)
def test_load_model_with_state_dict_only(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
state_dict = model.state_dict()
config = model.config
model_loaded = BertModel.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)
self.assertTrue(check_models_equal(model, model_loaded))
def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
state_dict = model.state_dict()
config = model.config
model_loaded = BertModel.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True
)
self.assertTrue(check_models_equal(model, model_loaded))
@slow
@require_torch