[generate, cache] handle more complex device maps (#37014)
This commit is contained in:
@@ -1656,11 +1656,10 @@ class GenerationMixin:
|
|||||||
model_kwargs["cache_position"] = cache_position
|
model_kwargs["cache_position"] = cache_position
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def _get_layer_device_map_for_cache_init(self):
|
def _get_layer_device_map_for_cache_init(self) -> Optional[Dict[int, Union[str, int]]]:
|
||||||
"""
|
"""
|
||||||
Taken from `dispatch_model` from accelerate.
|
Returns the device map for each decoder layer, to allocate the cache on the right device.
|
||||||
This is needed here if we don't want to make changes in accelerate in order to save execution_device
|
Inspired from `dispatch_model` in accelerate.
|
||||||
For offloaded case, we need to get the execution device, not just the device where it is offloaded
|
|
||||||
"""
|
"""
|
||||||
execution_device_map = None
|
execution_device_map = None
|
||||||
|
|
||||||
@@ -1674,17 +1673,62 @@ class GenerationMixin:
|
|||||||
for name, device in self.hf_device_map.items()
|
for name, device in self.hf_device_map.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
# No `execution_device_map` -> rely on `self.device` to allocate the cache
|
||||||
if execution_device_map is None:
|
if execution_device_map is None:
|
||||||
return None
|
return None
|
||||||
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
|
||||||
|
# Single device for all layers
|
||||||
|
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||||
|
if len(execution_device_map) == 1 and "" in execution_device_map:
|
||||||
return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
|
return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
|
||||||
|
|
||||||
|
# Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
|
||||||
layer_device_map = {}
|
layer_device_map = {}
|
||||||
|
# Case 1: The model has a `get_decoder` method, we can use it to find the decoder name.
|
||||||
|
if hasattr(self, "get_decoder"):
|
||||||
|
decoder_name = None
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if module is self.get_decoder():
|
||||||
|
decoder_name = name
|
||||||
|
break
|
||||||
|
if decoder_name is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"`model.get_decoder()` is not returning a named module of the model. This is unexpected, please "
|
||||||
|
"open an issue on GitHub."
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_mapped_modules = [
|
||||||
|
module_name for module_name in execution_device_map.keys() if decoder_name in module_name
|
||||||
|
]
|
||||||
|
# The decoder name may be present in `execution_device_map` in two forms:
|
||||||
|
# a) each layer has a device mapping
|
||||||
|
if len(decoder_mapped_modules) >= num_hidden_layers:
|
||||||
|
for idx in range(num_hidden_layers):
|
||||||
|
for module_name in decoder_mapped_modules:
|
||||||
|
if f".{idx}." in f"{module_name}.":
|
||||||
|
layer_device_map[idx] = execution_device_map[module_name]
|
||||||
|
break
|
||||||
|
|
||||||
|
# b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map,
|
||||||
|
# then the mapping is done in a parent module
|
||||||
|
else:
|
||||||
|
while True:
|
||||||
|
if decoder_name in execution_device_map:
|
||||||
|
layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
|
||||||
|
break
|
||||||
|
elif "." in decoder_name:
|
||||||
|
decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")
|
||||||
|
|
||||||
|
# Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index)
|
||||||
|
else:
|
||||||
for layer in execution_device_map:
|
for layer in execution_device_map:
|
||||||
for idx in range(num_hidden_layers):
|
for idx in range(num_hidden_layers):
|
||||||
if f".{idx}." in f"{layer}.":
|
if f".{idx}." in f"{layer}.":
|
||||||
layer_device_map[idx] = execution_device_map[layer]
|
layer_device_map[idx] = execution_device_map[layer]
|
||||||
break
|
break
|
||||||
|
|
||||||
for idx in range(num_hidden_layers):
|
for idx in range(num_hidden_layers):
|
||||||
if idx not in layer_device_map:
|
if idx not in layer_device_map:
|
||||||
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForImageTextToText,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSpeechSeq2Seq,
|
AutoModelForSpeechSeq2Seq,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
@@ -4720,6 +4721,60 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_cache_device_map_with_vision_layer_device_map(self):
|
||||||
|
"""
|
||||||
|
Test that the cache device map is correctly set when the vision layer has a device map. Regression test for
|
||||||
|
#36942
|
||||||
|
"""
|
||||||
|
# gemma 3 uses hybrid cache, which can be compiled -> needs a device map at allocation time
|
||||||
|
model_id = "google/gemma-3-4b-it"
|
||||||
|
|
||||||
|
# important part of this device map: the `.layers.` pattern is NOT present in the decoder
|
||||||
|
device_map = {
|
||||||
|
"vision_tower.vision_model.embeddings": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.0": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.1": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.2": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.3": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.4": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.5": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.6": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.7": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.8": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.9": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.10": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.11": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.12": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.13": 0,
|
||||||
|
"vision_tower.vision_model.encoder.layers.14": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.15": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.16": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.17": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.18": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.19": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.20": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.21": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.22": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.23": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.24": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.25": "cpu",
|
||||||
|
"vision_tower.vision_model.encoder.layers.26": "cpu",
|
||||||
|
"vision_tower.vision_model.post_layernorm": "cpu",
|
||||||
|
"multi_modal_projector": "cpu",
|
||||||
|
"language_model": "cpu",
|
||||||
|
}
|
||||||
|
|
||||||
|
model = AutoModelForImageTextToText.from_pretrained(
|
||||||
|
model_id, device_map=device_map, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
inputs = tokenizer(["This is a text input"], return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# If the generate doesn't infer the DECODER device map correctly, this will fail
|
||||||
|
_ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TokenHealingTestCase(unittest.TestCase):
|
class TokenHealingTestCase(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user