[Variant] Make sure variant files are not incorrectly deleted (#21562)

* [Variant] Make sure variant files are not incorrectly deleted

* Apply suggestions from code review

* fix
This commit is contained in:
Patrick von Platen
2023-02-10 16:44:51 +02:00
committed by GitHub
parent 51c3f42d8e
commit b20147a3c8
2 changed files with 27 additions and 0 deletions

View File

@@ -1718,11 +1718,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile("(.*?)-\d{5}-of-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in shards.keys()
and is_main_process
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)

View File

@@ -3119,6 +3119,27 @@ class ModelUtilsTest(TestCasePlus):
)
self.assertIsNotNone(model)
def test_checkpoint_variant_save_load(self):
with tempfile.TemporaryDirectory() as tmp_dir:
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
)
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
model.save_pretrained(tmp_dir, variant="v2")
# saving will create a variant checkpoint
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
model.save_pretrained(tmp_dir)
# saving shouldn't delete variant checkpoints
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
# there should be a normal checkpoint
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
self.assertIsNotNone(model)
@require_accelerate
def test_from_pretrained_low_cpu_mem_usage_functional(self):
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and