From b20147a3c8b981854c4115ce79d7b5c9bbc9aac9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Feb 2023 16:44:51 +0200 Subject: [PATCH] [Variant] Make sure variant files are not incorrectly deleted (#21562) * [Variant] Make sure variant files are not incorrectly deleted * Apply suggestions from code review * fix --- src/transformers/modeling_utils.py | 6 ++++++ tests/test_modeling_common.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 160cf814b8..f331eecddf 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1718,11 +1718,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # If we have a shard file that is not going to be replaced, we delete it, but only from the main process # in distributed settings to avoid race conditions. weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile("(.*?)-\d{5}-of-\d{5}") + if ( filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards.keys() and is_main_process + and reg.fullmatch(filename_no_suffix) is not None ): os.remove(full_filename) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 26711f660d..41217e266e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3119,6 +3119,27 @@ class ModelUtilsTest(TestCasePlus): ) self.assertIsNotNone(model) + def test_checkpoint_variant_save_load(self): + with tempfile.TemporaryDirectory() as tmp_dir: + model = BertModel.from_pretrained( + "hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2" + ) + weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"]) + + model.save_pretrained(tmp_dir, variant="v2") + # saving will create a variant checkpoint + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name))) + + model.save_pretrained(tmp_dir) + # saving shouldn't delete variant checkpoints + weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"]) + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name))) + + # there should be a normal checkpoint + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME))) + + self.assertIsNotNone(model) + @require_accelerate def test_from_pretrained_low_cpu_mem_usage_functional(self): # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and