[generate] skip compilation on cpu offload (#37709)
* skip compilation on cpu offload * add test * better logic * docstring * boolean logic * add disk offload check * warn users if compilation options are set but compilation doesn happen * fix test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -2245,13 +2245,15 @@ class GenerationTesterMixin:
|
||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
compile_config = CompileConfig()
|
||||
compile_config._compile_all_devices = True
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
model.language_model.generation_config.compile_config._compile_all_devices = True
|
||||
model.language_model.generation_config.compile_config = compile_config
|
||||
if not has_defined_cache_implementation:
|
||||
model.language_model.generation_config.cache_implementation = "static"
|
||||
else:
|
||||
# force compilation (e.g. fast CI, CPU)
|
||||
model.generation_config.compile_config._compile_all_devices = True
|
||||
model.generation_config.compile_config = compile_config
|
||||
if not has_defined_cache_implementation:
|
||||
model.generation_config.cache_implementation = "static"
|
||||
|
||||
@@ -4907,6 +4909,37 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# If the generate doesn't infer the DECODER device map correctly, this will fail
|
||||
_ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_cpu_offload_doesnt_compile(self):
|
||||
"""Test that CPU offload doesn't trigger compilation"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
tokenized_inputs = tokenizer(["Hello world"], return_tensors="pt")
|
||||
generate_kwargs = {"max_new_tokens": 3, "cache_implementation": "static"}
|
||||
|
||||
# Sanity check: if we don't specify a device map, the model will get compiled
|
||||
model_gpu = AutoModelForCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
|
||||
)
|
||||
input_ids = tokenized_inputs.input_ids.to(model_gpu.device)
|
||||
_ = model_gpu.generate(input_ids, **generate_kwargs)
|
||||
self.assertTrue(hasattr(model_gpu, "_compiled_call"))
|
||||
|
||||
# If we specify a device map, the model will not be compiled
|
||||
# (as of April 2025, compiling with CPU offload results in a crash)
|
||||
device_map = {
|
||||
"model.embed_tokens": 0,
|
||||
"model.layers.0": 0,
|
||||
"model.layers.1": "cpu",
|
||||
"model.norm": "cpu",
|
||||
"lm_head": 0,
|
||||
}
|
||||
model_cpu = AutoModelForCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
|
||||
)
|
||||
input_ids = tokenized_inputs.input_ids.to(model_cpu.device)
|
||||
_ = model_cpu.generate(input_ids, **generate_kwargs)
|
||||
self.assertFalse(hasattr(model_cpu, "_compiled_call"))
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user