[Variant] Make sure variant files are not incorrectly deleted (#21562)
* [Variant] Make sure variant files are not incorrectly deleted * Apply suggestions from code review * fix
This commit is contained in:
committed by
GitHub
parent
51c3f42d8e
commit
b20147a3c8
@@ -3119,6 +3119,27 @@ class ModelUtilsTest(TestCasePlus):
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_checkpoint_variant_save_load(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = BertModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
|
||||
)
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||
|
||||
model.save_pretrained(tmp_dir, variant="v2")
|
||||
# saving will create a variant checkpoint
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||
|
||||
model.save_pretrained(tmp_dir)
|
||||
# saving shouldn't delete variant checkpoints
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||
|
||||
# there should be a normal checkpoint
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_accelerate
|
||||
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
||||
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
||||
|
||||
Reference in New Issue
Block a user