From 29f322d04d700e9005e59d97352dde2979262c5e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 27 Mar 2025 14:33:20 +0000 Subject: [PATCH] [generate, cache] handle more complex device maps (#37014) --- src/transformers/generation/utils.py | 64 +++++++++++++++++++++++----- tests/generation/test_utils.py | 55 ++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e7780bfede..e39a262527 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1656,11 +1656,10 @@ class GenerationMixin: model_kwargs["cache_position"] = cache_position return model_kwargs - def _get_layer_device_map_for_cache_init(self): + def _get_layer_device_map_for_cache_init(self) -> Optional[Dict[int, Union[str, int]]]: """ - Taken from `dispatch_model` from accelerate. - This is needed here if we don't want to make changes in accelerate in order to save execution_device - For offloaded case, we need to get the execution device, not just the device where it is offloaded + Returns the device map for each decoder layer, to allocate the cache on the right device. + Inspired from `dispatch_model` in accelerate. """ execution_device_map = None @@ -1674,17 +1673,62 @@ class GenerationMixin: for name, device in self.hf_device_map.items() } - num_hidden_layers = self.config.get_text_config().num_hidden_layers + # No `execution_device_map` -> rely on `self.device` to allocate the cache if execution_device_map is None: return None - elif len(execution_device_map) == 1 and "" in execution_device_map: + + # Single device for all layers + num_hidden_layers = self.config.get_text_config().num_hidden_layers + if len(execution_device_map) == 1 and "" in execution_device_map: return dict.fromkeys(range(num_hidden_layers), execution_device_map[""]) + + # Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device. layer_device_map = {} - for layer in execution_device_map: - for idx in range(num_hidden_layers): - if f".{idx}." in f"{layer}.": - layer_device_map[idx] = execution_device_map[layer] + # Case 1: The model has a `get_decoder` method, we can use it to find the decoder name. + if hasattr(self, "get_decoder"): + decoder_name = None + for name, module in self.named_modules(): + if module is self.get_decoder(): + decoder_name = name break + if decoder_name is None: + raise RuntimeError( + "`model.get_decoder()` is not returning a named module of the model. This is unexpected, please " + "open an issue on GitHub." + ) + + decoder_mapped_modules = [ + module_name for module_name in execution_device_map.keys() if decoder_name in module_name + ] + # The decoder name may be present in `execution_device_map` in two forms: + # a) each layer has a device mapping + if len(decoder_mapped_modules) >= num_hidden_layers: + for idx in range(num_hidden_layers): + for module_name in decoder_mapped_modules: + if f".{idx}." in f"{module_name}.": + layer_device_map[idx] = execution_device_map[module_name] + break + + # b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map, + # then the mapping is done in a parent module + else: + while True: + if decoder_name in execution_device_map: + layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name]) + break + elif "." in decoder_name: + decoder_name = decoder_name.rsplit(".", 1)[0] # gets the name of the parent module + else: + raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map") + + # Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index) + else: + for layer in execution_device_map: + for idx in range(num_hidden_layers): + if f".{idx}." in f"{layer}.": + layer_device_map[idx] = execution_device_map[layer] + break + for idx in range(num_hidden_layers): if idx not in layer_device_map: raise RuntimeError(f"layer {idx} has not been mapped to a device.") diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f096d667d6..bccc344f17 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -56,6 +56,7 @@ if is_torch_available(): from transformers import ( AutoModelForCausalLM, + AutoModelForImageTextToText, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, @@ -4720,6 +4721,60 @@ class GenerationIntegrationTests(unittest.TestCase): self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids)) self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input)) + @slow + @require_torch_gpu + def test_cache_device_map_with_vision_layer_device_map(self): + """ + Test that the cache device map is correctly set when the vision layer has a device map. Regression test for + #36942 + """ + # gemma 3 uses hybrid cache, which can be compiled -> needs a device map at allocation time + model_id = "google/gemma-3-4b-it" + + # important part of this device map: the `.layers.` pattern is NOT present in the decoder + device_map = { + "vision_tower.vision_model.embeddings": 0, + "vision_tower.vision_model.encoder.layers.0": 0, + "vision_tower.vision_model.encoder.layers.1": 0, + "vision_tower.vision_model.encoder.layers.2": 0, + "vision_tower.vision_model.encoder.layers.3": 0, + "vision_tower.vision_model.encoder.layers.4": 0, + "vision_tower.vision_model.encoder.layers.5": 0, + "vision_tower.vision_model.encoder.layers.6": 0, + "vision_tower.vision_model.encoder.layers.7": 0, + "vision_tower.vision_model.encoder.layers.8": 0, + "vision_tower.vision_model.encoder.layers.9": 0, + "vision_tower.vision_model.encoder.layers.10": 0, + "vision_tower.vision_model.encoder.layers.11": 0, + "vision_tower.vision_model.encoder.layers.12": 0, + "vision_tower.vision_model.encoder.layers.13": 0, + "vision_tower.vision_model.encoder.layers.14": "cpu", + "vision_tower.vision_model.encoder.layers.15": "cpu", + "vision_tower.vision_model.encoder.layers.16": "cpu", + "vision_tower.vision_model.encoder.layers.17": "cpu", + "vision_tower.vision_model.encoder.layers.18": "cpu", + "vision_tower.vision_model.encoder.layers.19": "cpu", + "vision_tower.vision_model.encoder.layers.20": "cpu", + "vision_tower.vision_model.encoder.layers.21": "cpu", + "vision_tower.vision_model.encoder.layers.22": "cpu", + "vision_tower.vision_model.encoder.layers.23": "cpu", + "vision_tower.vision_model.encoder.layers.24": "cpu", + "vision_tower.vision_model.encoder.layers.25": "cpu", + "vision_tower.vision_model.encoder.layers.26": "cpu", + "vision_tower.vision_model.post_layernorm": "cpu", + "multi_modal_projector": "cpu", + "language_model": "cpu", + } + + model = AutoModelForImageTextToText.from_pretrained( + model_id, device_map=device_map, torch_dtype=torch.bfloat16 + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(["This is a text input"], return_tensors="pt").to(model.device) + + # If the generate doesn't infer the DECODER device map correctly, this will fail + _ = model.generate(**inputs, max_new_tokens=2, do_sample=False) + @require_torch class TokenHealingTestCase(unittest.TestCase):