Fix loading with only state dict and low_cpu_mem_usage = True (#35217)
* fix loading with only state dict and config * style * add tests --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -1750,6 +1750,26 @@ class ModelUtilsTest(TestCasePlus):
|
||||
new_model.generate(random_ids, max_new_tokens=3)
|
||||
self.assertTrue(len(w) == 0)
|
||||
|
||||
def test_load_model_with_state_dict_only(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
state_dict = model.state_dict()
|
||||
config = model.config
|
||||
|
||||
model_loaded = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
|
||||
)
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
state_dict = model.state_dict()
|
||||
config = model.config
|
||||
|
||||
model_loaded = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True
|
||||
)
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user