[generate, cache] handle more complex device maps (#37014)

This commit is contained in:
Joao Gante
2025-03-27 14:33:20 +00:00
committed by GitHub
parent fb8e6c50e4
commit 29f322d04d
2 changed files with 109 additions and 10 deletions

View File

@@ -56,6 +56,7 @@ if is_torch_available():
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
@@ -4720,6 +4721,60 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
@slow
@require_torch_gpu
def test_cache_device_map_with_vision_layer_device_map(self):
    """
    Regression test for #36942: the cache device map must be set correctly when only the vision layers carry a
    per-layer device map.
    """
    # gemma 3 uses a hybrid cache, which can be compiled -> the device map is needed at cache allocation time
    checkpoint = "google/gemma-3-4b-it"

    # Vision encoder layout: layers 0-13 sit on accelerator 0, layers 14-26 sit on CPU.
    vision_root = "vision_tower.vision_model"
    num_vision_layers = 27
    first_cpu_layer = 14
    vision_layers = {
        f"{vision_root}.encoder.layers.{idx}": (0 if idx < first_cpu_layer else "cpu")
        for idx in range(num_vision_layers)
    }

    # Important part of this device map: the `.layers.` pattern is NOT present in the decoder, which is mapped
    # as a single `language_model` entry.
    device_map = {
        f"{vision_root}.embeddings": 0,
        **vision_layers,
        f"{vision_root}.post_layernorm": "cpu",
        "multi_modal_projector": "cpu",
        "language_model": "cpu",
    }

    model = AutoModelForImageTextToText.from_pretrained(
        checkpoint,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer(["This is a text input"], return_tensors="pt")
    encoded = encoded.to(model.device)

    # This call fails if `generate` does not infer the DECODER device map correctly; the output is irrelevant.
    model.generate(**encoded, max_new_tokens=2, do_sample=False)
@require_torch
class TokenHealingTestCase(unittest.TestCase):