[generate, cache] handle more complex device maps (#37014)
This commit is contained in:
@@ -1656,11 +1656,10 @@ class GenerationMixin:
|
||||
model_kwargs["cache_position"] = cache_position
|
||||
return model_kwargs
|
||||
|
||||
def _get_layer_device_map_for_cache_init(self):
|
||||
def _get_layer_device_map_for_cache_init(self) -> Optional[Dict[int, Union[str, int]]]:
|
||||
"""
|
||||
Taken from `dispatch_model` from accelerate.
|
||||
This is needed here if we don't want to make changes in accelerate in order to save execution_device
|
||||
For offloaded case, we need to get the execution device, not just the device where it is offloaded
|
||||
Returns the device map for each decoder layer, to allocate the cache on the right device.
|
||||
Inspired from `dispatch_model` in accelerate.
|
||||
"""
|
||||
execution_device_map = None
|
||||
|
||||
@@ -1674,17 +1673,62 @@ class GenerationMixin:
|
||||
for name, device in self.hf_device_map.items()
|
||||
}
|
||||
|
||||
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||
# No `execution_device_map` -> rely on `self.device` to allocate the cache
|
||||
if execution_device_map is None:
|
||||
return None
|
||||
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
||||
|
||||
# Single device for all layers
|
||||
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||
if len(execution_device_map) == 1 and "" in execution_device_map:
|
||||
return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
|
||||
|
||||
# Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
|
||||
layer_device_map = {}
|
||||
for layer in execution_device_map:
|
||||
for idx in range(num_hidden_layers):
|
||||
if f".{idx}." in f"{layer}.":
|
||||
layer_device_map[idx] = execution_device_map[layer]
|
||||
# Case 1: The model has a `get_decoder` method, we can use it to find the decoder name.
|
||||
if hasattr(self, "get_decoder"):
|
||||
decoder_name = None
|
||||
for name, module in self.named_modules():
|
||||
if module is self.get_decoder():
|
||||
decoder_name = name
|
||||
break
|
||||
if decoder_name is None:
|
||||
raise RuntimeError(
|
||||
"`model.get_decoder()` is not returning a named module of the model. This is unexpected, please "
|
||||
"open an issue on GitHub."
|
||||
)
|
||||
|
||||
decoder_mapped_modules = [
|
||||
module_name for module_name in execution_device_map.keys() if decoder_name in module_name
|
||||
]
|
||||
# The decoder name may be present in `execution_device_map` in two forms:
|
||||
# a) each layer has a device mapping
|
||||
if len(decoder_mapped_modules) >= num_hidden_layers:
|
||||
for idx in range(num_hidden_layers):
|
||||
for module_name in decoder_mapped_modules:
|
||||
if f".{idx}." in f"{module_name}.":
|
||||
layer_device_map[idx] = execution_device_map[module_name]
|
||||
break
|
||||
|
||||
# b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map,
|
||||
# then the mapping is done in a parent module
|
||||
else:
|
||||
while True:
|
||||
if decoder_name in execution_device_map:
|
||||
layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
|
||||
break
|
||||
elif "." in decoder_name:
|
||||
decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module
|
||||
else:
|
||||
raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")
|
||||
|
||||
# Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index)
|
||||
else:
|
||||
for layer in execution_device_map:
|
||||
for idx in range(num_hidden_layers):
|
||||
if f".{idx}." in f"{layer}.":
|
||||
layer_device_map[idx] = execution_device_map[layer]
|
||||
break
|
||||
|
||||
for idx in range(num_hidden_layers):
|
||||
if idx not in layer_device_map:
|
||||
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
||||
|
||||
Reference in New Issue
Block a user