From 1ac599d90f740ce28f637ad32ff5f59c40cd5a0a Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 15 Nov 2023 20:58:08 +0100 Subject: [PATCH] Fix offload disk for loading derivated model checkpoint into base model (#27253) * fix * style * add test --- src/transformers/modeling_utils.py | 22 ++++++++++------ tests/test_modeling_utils.py | 40 ++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fcb51e6a56..57eb08a415 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3793,8 +3793,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: folder = None if device_map is not None and is_safetensors: - param_device_map = expand_device_map(device_map, original_loaded_keys) - + param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" if sharded_metadata is None: archive_file = ( @@ -3806,9 +3805,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} offload_index = { - p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} + p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} for p, f in weight_map.items() - if param_device_map[p] == "disk" + if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" } if state_dict is not None: @@ -3842,7 +3841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix state_dict_index = None if is_sharded_safetensors: - disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata) + disk_only_shard_files = get_disk_only_shard_files( + device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix + ) disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] else: disk_only_shard_files = [] @@ -4576,11 +4577,12 @@ def unwrap_model(model: nn.Module) -> nn.Module: return model -def expand_device_map(device_map, param_names): +def expand_device_map(device_map, param_names, start_prefix): """ Expand a device map to return the correspondance parameter name to device. """ new_device_map = {} + param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)] for module, device in device_map.items(): new_device_map.update( {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""} @@ -4588,12 +4590,16 @@ def expand_device_map(device_map, param_names): return new_device_map -def get_disk_only_shard_files(device_map, sharded_metadata): +def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): """ Returns the list of shard files containing only weights offloaded to disk. """ + + weight_map = { + p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix) + } files_content = collections.defaultdict(list) - for weight_name, filename in sharded_metadata["weight_map"].items(): + for weight_name, filename in weight_map.items(): while len(weight_name) > 0 and weight_name not in device_map: weight_name = ".".join(weight_name.split(".")[:-1]) files_content[filename].append(device_map[weight_name]) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index e457dc07a9..62a639a95f 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -750,6 +750,46 @@ class ModelUtilsTest(TestCasePlus): self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu())) + @require_accelerate + @mark.accelerate_tests + @require_torch_accelerator + def test_from_pretrained_disk_offload_derived_to_base_model(self): + derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") + + device_map = { + "wte": 0, + "wpe": 0, + "h.0": "cpu", + "h.1": "cpu", + "h.2": "cpu", + "h.3": "disk", + "h.4": "disk", + "ln_f": 0, + } + with tempfile.TemporaryDirectory() as tmp_dir: + inputs = torch.tensor([[1, 2, 3]]).to(0) + derived_model.save_pretrained(tmp_dir, use_safetensors=True) + base_model = AutoModel.from_pretrained(tmp_dir) + outputs1 = base_model.to(0)(inputs) + + # with disk offload + offload_folder = os.path.join(tmp_dir, "offload") + base_model_with_offload = AutoModel.from_pretrained( + tmp_dir, device_map=device_map, offload_folder=offload_folder + ) + outputs2 = base_model_with_offload(inputs) + self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu())) + + # With state dict temp offload + new_model_with_offload = AutoModel.from_pretrained( + tmp_dir, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=True, + ) + outputs2 = new_model_with_offload(inputs) + self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu())) + def test_cached_files_are_used_when_internet_is_down(self): # A mock response for an HTTP head request to emulate server down response_mock = mock.Mock()