Fix offload disk for loading derivated model checkpoint into base model (#27253)
* fix * style * add test
This commit is contained in:
@@ -3793,8 +3793,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
folder = None
|
folder = None
|
||||||
if device_map is not None and is_safetensors:
|
if device_map is not None and is_safetensors:
|
||||||
param_device_map = expand_device_map(device_map, original_loaded_keys)
|
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
||||||
|
|
||||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||||
if sharded_metadata is None:
|
if sharded_metadata is None:
|
||||||
archive_file = (
|
archive_file = (
|
||||||
@@ -3806,9 +3805,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
||||||
offload_index = {
|
offload_index = {
|
||||||
p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
||||||
for p, f in weight_map.items()
|
for p, f in weight_map.items()
|
||||||
if param_device_map[p] == "disk"
|
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
||||||
}
|
}
|
||||||
|
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
@@ -3842,7 +3841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
state_dict_index = None
|
state_dict_index = None
|
||||||
|
|
||||||
if is_sharded_safetensors:
|
if is_sharded_safetensors:
|
||||||
disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
|
disk_only_shard_files = get_disk_only_shard_files(
|
||||||
|
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
|
||||||
|
)
|
||||||
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
|
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
|
||||||
else:
|
else:
|
||||||
disk_only_shard_files = []
|
disk_only_shard_files = []
|
||||||
@@ -4576,11 +4577,12 @@ def unwrap_model(model: nn.Module) -> nn.Module:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def expand_device_map(device_map, param_names):
|
def expand_device_map(device_map, param_names, start_prefix):
|
||||||
"""
|
"""
|
||||||
Expand a device map to return the correspondance parameter name to device.
|
Expand a device map to return the correspondance parameter name to device.
|
||||||
"""
|
"""
|
||||||
new_device_map = {}
|
new_device_map = {}
|
||||||
|
param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
|
||||||
for module, device in device_map.items():
|
for module, device in device_map.items():
|
||||||
new_device_map.update(
|
new_device_map.update(
|
||||||
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
|
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
|
||||||
@@ -4588,12 +4590,16 @@ def expand_device_map(device_map, param_names):
|
|||||||
return new_device_map
|
return new_device_map
|
||||||
|
|
||||||
|
|
||||||
def get_disk_only_shard_files(device_map, sharded_metadata):
|
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
||||||
"""
|
"""
|
||||||
Returns the list of shard files containing only weights offloaded to disk.
|
Returns the list of shard files containing only weights offloaded to disk.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
weight_map = {
|
||||||
|
p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
|
||||||
|
}
|
||||||
files_content = collections.defaultdict(list)
|
files_content = collections.defaultdict(list)
|
||||||
for weight_name, filename in sharded_metadata["weight_map"].items():
|
for weight_name, filename in weight_map.items():
|
||||||
while len(weight_name) > 0 and weight_name not in device_map:
|
while len(weight_name) > 0 and weight_name not in device_map:
|
||||||
weight_name = ".".join(weight_name.split(".")[:-1])
|
weight_name = ".".join(weight_name.split(".")[:-1])
|
||||||
files_content[filename].append(device_map[weight_name])
|
files_content[filename].append(device_map[weight_name])
|
||||||
|
|||||||
@@ -750,6 +750,46 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
|
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@mark.accelerate_tests
|
||||||
|
@require_torch_accelerator
|
||||||
|
def test_from_pretrained_disk_offload_derived_to_base_model(self):
|
||||||
|
derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
|
||||||
|
device_map = {
|
||||||
|
"wte": 0,
|
||||||
|
"wpe": 0,
|
||||||
|
"h.0": "cpu",
|
||||||
|
"h.1": "cpu",
|
||||||
|
"h.2": "cpu",
|
||||||
|
"h.3": "disk",
|
||||||
|
"h.4": "disk",
|
||||||
|
"ln_f": 0,
|
||||||
|
}
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
inputs = torch.tensor([[1, 2, 3]]).to(0)
|
||||||
|
derived_model.save_pretrained(tmp_dir, use_safetensors=True)
|
||||||
|
base_model = AutoModel.from_pretrained(tmp_dir)
|
||||||
|
outputs1 = base_model.to(0)(inputs)
|
||||||
|
|
||||||
|
# with disk offload
|
||||||
|
offload_folder = os.path.join(tmp_dir, "offload")
|
||||||
|
base_model_with_offload = AutoModel.from_pretrained(
|
||||||
|
tmp_dir, device_map=device_map, offload_folder=offload_folder
|
||||||
|
)
|
||||||
|
outputs2 = base_model_with_offload(inputs)
|
||||||
|
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
|
||||||
|
|
||||||
|
# With state dict temp offload
|
||||||
|
new_model_with_offload = AutoModel.from_pretrained(
|
||||||
|
tmp_dir,
|
||||||
|
device_map=device_map,
|
||||||
|
offload_folder=offload_folder,
|
||||||
|
offload_state_dict=True,
|
||||||
|
)
|
||||||
|
outputs2 = new_model_with_offload(inputs)
|
||||||
|
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
|
||||||
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
|
|||||||
Reference in New Issue
Block a user