[generate, cache] handle more complex device maps (#37014)

This commit is contained in:
Joao Gante
2025-03-27 14:33:20 +00:00
committed by GitHub
parent fb8e6c50e4
commit 29f322d04d
2 changed files with 109 additions and 10 deletions

View File

@@ -1656,11 +1656,10 @@ class GenerationMixin:
model_kwargs["cache_position"] = cache_position
return model_kwargs
def _get_layer_device_map_for_cache_init(self):
def _get_layer_device_map_for_cache_init(self) -> Optional[Dict[int, Union[str, int]]]:
"""
Taken from `dispatch_model` from accelerate.
This is needed here if we don't want to make changes in accelerate in order to save execution_device
For offloaded case, we need to get the execution device, not just the device where it is offloaded
Returns the device map for each decoder layer, to allocate the cache on the right device.
Inspired from `dispatch_model` in accelerate.
"""
execution_device_map = None
@@ -1674,17 +1673,62 @@ class GenerationMixin:
for name, device in self.hf_device_map.items()
}
num_hidden_layers = self.config.get_text_config().num_hidden_layers
# No `execution_device_map` -> rely on `self.device` to allocate the cache
if execution_device_map is None:
return None
elif len(execution_device_map) == 1 and "" in execution_device_map:
# Single device for all layers
num_hidden_layers = self.config.get_text_config().num_hidden_layers
if len(execution_device_map) == 1 and "" in execution_device_map:
return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
# Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
layer_device_map = {}
for layer in execution_device_map:
for idx in range(num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
# Case 1: The model has a `get_decoder` method, we can use it to find the decoder name.
if hasattr(self, "get_decoder"):
decoder_name = None
for name, module in self.named_modules():
if module is self.get_decoder():
decoder_name = name
break
if decoder_name is None:
raise RuntimeError(
"`model.get_decoder()` is not returning a named module of the model. This is unexpected, please "
"open an issue on GitHub."
)
decoder_mapped_modules = [
module_name for module_name in execution_device_map.keys() if decoder_name in module_name
]
# The decoder name may be present in `execution_device_map` in two forms:
# a) each layer has a device mapping
if len(decoder_mapped_modules) >= num_hidden_layers:
for idx in range(num_hidden_layers):
for module_name in decoder_mapped_modules:
if f".{idx}." in f"{module_name}.":
layer_device_map[idx] = execution_device_map[module_name]
break
# b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map,
# then the mapping is done in a parent module
else:
while True:
if decoder_name in execution_device_map:
layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
break
elif "." in decoder_name:
decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module
else:
raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")
# Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index)
else:
for layer in execution_device_map:
for idx in range(num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")

View File

@@ -56,6 +56,7 @@ if is_torch_available():
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
@@ -4720,6 +4721,60 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
@slow
@require_torch_gpu
def test_cache_device_map_with_vision_layer_device_map(self):
"""
Test that the cache device map is correctly set when the vision layer has a device map. Regression test for
#36942
"""
# gemma 3 uses hybrid cache, which can be compiled -> needs a device map at allocation time
model_id = "google/gemma-3-4b-it"
# important part of this device map: the `.layers.` pattern is NOT present in the decoder
device_map = {
"vision_tower.vision_model.embeddings": 0,
"vision_tower.vision_model.encoder.layers.0": 0,
"vision_tower.vision_model.encoder.layers.1": 0,
"vision_tower.vision_model.encoder.layers.2": 0,
"vision_tower.vision_model.encoder.layers.3": 0,
"vision_tower.vision_model.encoder.layers.4": 0,
"vision_tower.vision_model.encoder.layers.5": 0,
"vision_tower.vision_model.encoder.layers.6": 0,
"vision_tower.vision_model.encoder.layers.7": 0,
"vision_tower.vision_model.encoder.layers.8": 0,
"vision_tower.vision_model.encoder.layers.9": 0,
"vision_tower.vision_model.encoder.layers.10": 0,
"vision_tower.vision_model.encoder.layers.11": 0,
"vision_tower.vision_model.encoder.layers.12": 0,
"vision_tower.vision_model.encoder.layers.13": 0,
"vision_tower.vision_model.encoder.layers.14": "cpu",
"vision_tower.vision_model.encoder.layers.15": "cpu",
"vision_tower.vision_model.encoder.layers.16": "cpu",
"vision_tower.vision_model.encoder.layers.17": "cpu",
"vision_tower.vision_model.encoder.layers.18": "cpu",
"vision_tower.vision_model.encoder.layers.19": "cpu",
"vision_tower.vision_model.encoder.layers.20": "cpu",
"vision_tower.vision_model.encoder.layers.21": "cpu",
"vision_tower.vision_model.encoder.layers.22": "cpu",
"vision_tower.vision_model.encoder.layers.23": "cpu",
"vision_tower.vision_model.encoder.layers.24": "cpu",
"vision_tower.vision_model.encoder.layers.25": "cpu",
"vision_tower.vision_model.encoder.layers.26": "cpu",
"vision_tower.vision_model.post_layernorm": "cpu",
"multi_modal_projector": "cpu",
"language_model": "cpu",
}
model = AutoModelForImageTextToText.from_pretrained(
model_id, device_map=device_map, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(["This is a text input"], return_tensors="pt").to(model.device)
# If the generate doesn't infer the DECODER device map correctly, this will fail
_ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
@require_torch
class TokenHealingTestCase(unittest.TestCase):