[generate, cache] handle more complex device maps (#37014)
This commit is contained in:
@@ -1656,11 +1656,10 @@ class GenerationMixin:
|
||||
model_kwargs["cache_position"] = cache_position
|
||||
return model_kwargs
|
||||
|
||||
def _get_layer_device_map_for_cache_init(self):
|
||||
def _get_layer_device_map_for_cache_init(self) -> Optional[Dict[int, Union[str, int]]]:
|
||||
"""
|
||||
Taken from `dispatch_model` from accelerate.
|
||||
This is needed here if we don't want to make changes in accelerate in order to save execution_device
|
||||
For offloaded case, we need to get the execution device, not just the device where it is offloaded
|
||||
Returns the device map for each decoder layer, to allocate the cache on the right device.
|
||||
Inspired from `dispatch_model` in accelerate.
|
||||
"""
|
||||
execution_device_map = None
|
||||
|
||||
@@ -1674,17 +1673,62 @@ class GenerationMixin:
|
||||
for name, device in self.hf_device_map.items()
|
||||
}
|
||||
|
||||
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||
# No `execution_device_map` -> rely on `self.device` to allocate the cache
|
||||
if execution_device_map is None:
|
||||
return None
|
||||
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
||||
|
||||
# Single device for all layers
|
||||
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||
if len(execution_device_map) == 1 and "" in execution_device_map:
|
||||
return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
|
||||
|
||||
# Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
|
||||
layer_device_map = {}
|
||||
for layer in execution_device_map:
|
||||
for idx in range(num_hidden_layers):
|
||||
if f".{idx}." in f"{layer}.":
|
||||
layer_device_map[idx] = execution_device_map[layer]
|
||||
# Case 1: The model has a `get_decoder` method, we can use it to find the decoder name.
|
||||
if hasattr(self, "get_decoder"):
|
||||
decoder_name = None
|
||||
for name, module in self.named_modules():
|
||||
if module is self.get_decoder():
|
||||
decoder_name = name
|
||||
break
|
||||
if decoder_name is None:
|
||||
raise RuntimeError(
|
||||
"`model.get_decoder()` is not returning a named module of the model. This is unexpected, please "
|
||||
"open an issue on GitHub."
|
||||
)
|
||||
|
||||
decoder_mapped_modules = [
|
||||
module_name for module_name in execution_device_map.keys() if decoder_name in module_name
|
||||
]
|
||||
# The decoder name may be present in `execution_device_map` in two forms:
|
||||
# a) each layer has a device mapping
|
||||
if len(decoder_mapped_modules) >= num_hidden_layers:
|
||||
for idx in range(num_hidden_layers):
|
||||
for module_name in decoder_mapped_modules:
|
||||
if f".{idx}." in f"{module_name}.":
|
||||
layer_device_map[idx] = execution_device_map[module_name]
|
||||
break
|
||||
|
||||
# b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map,
|
||||
# then the mapping is done in a parent module
|
||||
else:
|
||||
while True:
|
||||
if decoder_name in execution_device_map:
|
||||
layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
|
||||
break
|
||||
elif "." in decoder_name:
|
||||
decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module
|
||||
else:
|
||||
raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")
|
||||
|
||||
# Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index)
|
||||
else:
|
||||
for layer in execution_device_map:
|
||||
for idx in range(num_hidden_layers):
|
||||
if f".{idx}." in f"{layer}.":
|
||||
layer_device_map[idx] = execution_device_map[layer]
|
||||
break
|
||||
|
||||
for idx in range(num_hidden_layers):
|
||||
if idx not in layer_device_map:
|
||||
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
||||
|
||||
@@ -56,6 +56,7 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForVision2Seq,
|
||||
@@ -4720,6 +4721,60 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_cache_device_map_with_vision_layer_device_map(self):
|
||||
"""
|
||||
Test that the cache device map is correctly set when the vision layer has a device map. Regression test for
|
||||
#36942
|
||||
"""
|
||||
# gemma 3 uses hybrid cache, which can be compiled -> needs a device map at allocation time
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
|
||||
# important part of this device map: the `.layers.` pattern is NOT present in the decoder
|
||||
device_map = {
|
||||
"vision_tower.vision_model.embeddings": 0,
|
||||
"vision_tower.vision_model.encoder.layers.0": 0,
|
||||
"vision_tower.vision_model.encoder.layers.1": 0,
|
||||
"vision_tower.vision_model.encoder.layers.2": 0,
|
||||
"vision_tower.vision_model.encoder.layers.3": 0,
|
||||
"vision_tower.vision_model.encoder.layers.4": 0,
|
||||
"vision_tower.vision_model.encoder.layers.5": 0,
|
||||
"vision_tower.vision_model.encoder.layers.6": 0,
|
||||
"vision_tower.vision_model.encoder.layers.7": 0,
|
||||
"vision_tower.vision_model.encoder.layers.8": 0,
|
||||
"vision_tower.vision_model.encoder.layers.9": 0,
|
||||
"vision_tower.vision_model.encoder.layers.10": 0,
|
||||
"vision_tower.vision_model.encoder.layers.11": 0,
|
||||
"vision_tower.vision_model.encoder.layers.12": 0,
|
||||
"vision_tower.vision_model.encoder.layers.13": 0,
|
||||
"vision_tower.vision_model.encoder.layers.14": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.15": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.16": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.17": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.18": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.19": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.20": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.21": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.22": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.23": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.24": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.25": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.26": "cpu",
|
||||
"vision_tower.vision_model.post_layernorm": "cpu",
|
||||
"multi_modal_projector": "cpu",
|
||||
"language_model": "cpu",
|
||||
}
|
||||
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id, device_map=device_map, torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(["This is a text input"], return_tensors="pt").to(model.device)
|
||||
|
||||
# If the generate doesn't infer the DECODER device map correctly, this will fail
|
||||
_ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user