[generate] skip compilation on cpu offload (#37709)

* skip compilation on cpu offload

* add test

* better logic

* docstring

* boolean logic

* add disk offload check

* warn users if compilation options are set but compilation doesn happen

* fix test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Joao Gante
2025-04-24 14:08:17 +01:00
committed by GitHub
parent 7c62e69326
commit 8bdd4f2acd
4 changed files with 91 additions and 21 deletions

View File

@@ -2245,13 +2245,15 @@ class GenerationTesterMixin:
# BLIP is the only exception with custom generate which call `self.lm.generate()`
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
# compatible with multimodality
compile_config = CompileConfig()
compile_config._compile_all_devices = True
if "blip" in model.__class__.__name__.lower():
model.language_model.generation_config.compile_config._compile_all_devices = True
model.language_model.generation_config.compile_config = compile_config
if not has_defined_cache_implementation:
model.language_model.generation_config.cache_implementation = "static"
else:
# force compilation (e.g. fast CI, CPU)
model.generation_config.compile_config._compile_all_devices = True
model.generation_config.compile_config = compile_config
if not has_defined_cache_implementation:
model.generation_config.cache_implementation = "static"
@@ -4907,6 +4909,37 @@ class GenerationIntegrationTests(unittest.TestCase):
# If the generate doesn't infer the DECODER device map correctly, this will fail
_ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
@require_torch_gpu
def test_cpu_offload_doesnt_compile(self):
"""Test that CPU offload doesn't trigger compilation"""
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
tokenized_inputs = tokenizer(["Hello world"], return_tensors="pt")
generate_kwargs = {"max_new_tokens": 3, "cache_implementation": "static"}
# Sanity check: if we don't specify a device map, the model will get compiled
model_gpu = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
)
input_ids = tokenized_inputs.input_ids.to(model_gpu.device)
_ = model_gpu.generate(input_ids, **generate_kwargs)
self.assertTrue(hasattr(model_gpu, "_compiled_call"))
# If we specify a device map, the model will not be compiled
# (as of April 2025, compiling with CPU offload results in a crash)
device_map = {
"model.embed_tokens": 0,
"model.layers.0": 0,
"model.layers.1": "cpu",
"model.norm": "cpu",
"lm_head": 0,
}
model_cpu = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
input_ids = tokenized_inputs.input_ids.to(model_cpu.device)
_ = model_cpu.generate(input_ids, **generate_kwargs)
self.assertFalse(hasattr(model_cpu, "_compiled_call"))
@require_torch
class TokenHealingTestCase(unittest.TestCase):