Clean load keys (#24505)
* Preliminary work on some models * Fix test load missing and make sure nonpersistent buffers are tested * Always ignore nonpersistent buffers if in state_dict * Treat models * More models * Treat remaining models * Fix quality * Fix tests * Remove draft * This test is not needed anymore * Fix copies * Fix last test * Newly added models * Fix last tests * Address review comments
This commit is contained in:
@@ -500,8 +500,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertTrue(os.path.isfile(weights_index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
|
||||
|
||||
for i in range(1, 6):
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["bin"])
|
||||
for i in range(1, 5):
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["bin"])
|
||||
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||
self.assertTrue(os.path.isfile(weights_name_file))
|
||||
|
||||
@@ -546,8 +546,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertTrue(os.path.isfile(weights_index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
for i in range(1, 6):
|
||||
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["safetensors"])
|
||||
for i in range(1, 5):
|
||||
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["safetensors"])
|
||||
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||
self.assertTrue(os.path.isfile(weights_name_file))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user