[generate, cache] handle more complex device maps (#37014)
This commit is contained in:
@@ -56,6 +56,7 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForVision2Seq,
|
||||
@@ -4720,6 +4721,60 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_cache_device_map_with_vision_layer_device_map(self):
|
||||
"""
|
||||
Test that the cache device map is correctly set when the vision layer has a device map. Regression test for
|
||||
#36942
|
||||
"""
|
||||
# gemma 3 uses hybrid cache, which can be compiled -> needs a device map at allocation time
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
|
||||
# important part of this device map: the `.layers.` pattern is NOT present in the decoder
|
||||
device_map = {
|
||||
"vision_tower.vision_model.embeddings": 0,
|
||||
"vision_tower.vision_model.encoder.layers.0": 0,
|
||||
"vision_tower.vision_model.encoder.layers.1": 0,
|
||||
"vision_tower.vision_model.encoder.layers.2": 0,
|
||||
"vision_tower.vision_model.encoder.layers.3": 0,
|
||||
"vision_tower.vision_model.encoder.layers.4": 0,
|
||||
"vision_tower.vision_model.encoder.layers.5": 0,
|
||||
"vision_tower.vision_model.encoder.layers.6": 0,
|
||||
"vision_tower.vision_model.encoder.layers.7": 0,
|
||||
"vision_tower.vision_model.encoder.layers.8": 0,
|
||||
"vision_tower.vision_model.encoder.layers.9": 0,
|
||||
"vision_tower.vision_model.encoder.layers.10": 0,
|
||||
"vision_tower.vision_model.encoder.layers.11": 0,
|
||||
"vision_tower.vision_model.encoder.layers.12": 0,
|
||||
"vision_tower.vision_model.encoder.layers.13": 0,
|
||||
"vision_tower.vision_model.encoder.layers.14": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.15": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.16": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.17": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.18": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.19": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.20": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.21": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.22": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.23": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.24": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.25": "cpu",
|
||||
"vision_tower.vision_model.encoder.layers.26": "cpu",
|
||||
"vision_tower.vision_model.post_layernorm": "cpu",
|
||||
"multi_modal_projector": "cpu",
|
||||
"language_model": "cpu",
|
||||
}
|
||||
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id, device_map=device_map, torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(["This is a text input"], return_tensors="pt").to(model.device)
|
||||
|
||||
# If the generate doesn't infer the DECODER device map correctly, this will fail
|
||||
_ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user