Init cache on meta device (#35164)

* init cache on meta device

* offloaded static + enable tests

* tests weren't running before  :(

* update

* fix mamba

* fix copies

* update

* address comments and fix tests

* fix copies

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* mamba fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-01-22 09:49:17 +01:00
committed by GitHub
parent 870e2c8ea0
commit 373e50e970
10 changed files with 111 additions and 111 deletions

View File

@@ -1069,12 +1069,15 @@ class StaticCache(Cache):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
Example:
```python
@@ -1096,6 +1099,7 @@ class StaticCache(Cache):
"""
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@@ -1122,6 +1126,7 @@ class StaticCache(Cache):
)
self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1136,7 +1141,7 @@ class StaticCache(Cache):
if layer_device_map is not None:
layer_device = layer_device_map[idx]
else:
layer_device = device
layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
@@ -1181,6 +1186,9 @@ class StaticCache(Cache):
"""
cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
@@ -1209,6 +1217,8 @@ class StaticCache(Cache):
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def get_max_cache_shape(self) -> Optional[int]:
@@ -1217,6 +1227,7 @@ class StaticCache(Cache):
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
@@ -1257,6 +1268,8 @@ class SlidingWindowCache(StaticCache):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
@@ -1321,8 +1334,15 @@ class SlidingWindowCache(StaticCache):
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] > self.max_cache_len:
@@ -1365,6 +1385,7 @@ class SlidingWindowCache(StaticCache):
def reset(self):
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
@@ -1561,8 +1582,10 @@ class HybridCache(Cache):
smaller batch size is used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
@@ -1590,12 +1613,13 @@ class HybridCache(Cache):
"""
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: Union[torch.device, str] = "cpu",
device: Union[torch.device, str] = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1623,9 +1647,11 @@ class HybridCache(Cache):
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.device = torch.device(device) if device is not None else torch.device("meta")
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@@ -1640,7 +1666,7 @@ class HybridCache(Cache):
if layer_device_map is not None:
layer_device = layer_device_map[i]
else:
layer_device = device
layer_device = self.device
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
@@ -1696,8 +1722,16 @@ class HybridCache(Cache):
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if sliding_window:
update_fn = self._sliding_update
else:
@@ -1725,11 +1759,15 @@ class HybridCache(Cache):
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
@@ -1757,10 +1795,14 @@ class MambaCache:
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
Attributes:
dtype: (`torch.dtype`):
The default `dtype` used to initializing the cache.
device (`torch.device`):
The default device on which the cache was initialized.
intermediate_size: (`int`):
Model's intermediate_size taken from config.
ssm_state_size: (`int`):
@@ -1809,30 +1851,40 @@ class MambaCache:
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.device = torch.device(device) if device is not None else torch.device("meta")
self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.conv_states: List[torch.Tensor] = []
self.ssm_states: List[torch.Tensor] = []
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
device=self.device,
dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
device=self.device,
dtype=dtype,
)
torch._dynamo.mark_static_address(self.conv_states)
torch._dynamo.mark_static_address(self.ssm_states)
torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(ssm_state)
self.conv_states.append(conv_state)
self.ssm_states.append(ssm_state)
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
if self.conv_states[layer_idx].device.type == "meta":
self.conv_states[layer_idx] = torch.zeros_like(
self.conv_states[layer_idx],
device=new_conv_state.device,
)
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
@@ -1843,12 +1895,15 @@ class MambaCache:
return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx]
def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
for layer_idx in range(len(self.conv_states)):
if self.conv_states[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
@property
def batch_size(self):
@@ -1920,6 +1975,7 @@ class OffloadedStaticCache(StaticCache):
```
"""
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@@ -1930,9 +1986,10 @@ class OffloadedStaticCache(StaticCache):
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super(Cache, self).__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32

View File

@@ -1633,45 +1633,12 @@ class GenerationMixin:
# models. May cause trobles with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype
def get_layer_device_map(execution_device_map: Optional[dict] = None):
num_hidden_layers = self.config.get_text_config().num_hidden_layers
if execution_device_map is None:
return None
elif len(execution_device_map) == 1 and "" in execution_device_map:
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
layer_device_map = {}
for layer in execution_device_map:
for idx in range(num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map
execution_device_map = None
# Taken from dispatch_model from accelerate.
# This is needed here if we don't want to make changes in accelerate in order to save execution_device
# For offloaded case, we need to get the execution device, not just the device where it is offloaded
if hasattr(self, "hf_device_map"):
if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
layer_device_map = get_layer_device_map(execution_device_map)
cache_kwargs = {
"config": self.config.get_text_config(),
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
"layer_device_map": layer_device_map,
"device": device if cache_implementation == "offloaded_static" else None,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:

View File

@@ -73,6 +73,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.dtype,
device=self.model.generation_config.cache_config.device,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal:

View File

@@ -582,7 +582,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)

View File

@@ -461,7 +461,6 @@ class Cohere2Model(Gemma2Model):
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)

View File

@@ -579,7 +579,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)

View File

@@ -405,7 +405,6 @@ class Gemma2Model(GemmaModel):
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)

View File

@@ -728,22 +728,13 @@ class LlamaIntegrationTest(unittest.TestCase):
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
# Static Cache
# Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
# Static Cache + compile
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
@slow
@require_read_token
def test_export_static_cache(self):
@@ -795,6 +786,7 @@ class LlamaIntegrationTest(unittest.TestCase):
cache_config={
"batch_size": batch_size,
"max_cache_len": max_generation_length,
"device": device,
},
),
)

View File

@@ -4635,6 +4635,11 @@ class ModelTesterMixin:
fa2_correctly_converted = True
break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else fa2_model.config._attn_implementation == "flash_attention_2"
)
self.assertTrue(fa2_correctly_converted)
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
@@ -4653,6 +4658,11 @@ class ModelTesterMixin:
fa2_correctly_converted = True
break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
)
self.assertFalse(fa2_correctly_converted)
def _get_custom_4d_mask_test_data(self):

View File

@@ -198,6 +198,7 @@ class CacheTest(unittest.TestCase):
cache_config={
"batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
},
),
)
@@ -310,11 +311,12 @@ class CacheIntegrationTest(unittest.TestCase):
do_sample=False,
max_new_tokens=20,
num_return_sequences=2,
num_beams=2,
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = [
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
"Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",
]
self.assertListEqual(decoded, expected_text)
@@ -380,8 +382,6 @@ class CacheIntegrationTest(unittest.TestCase):
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
@@ -427,8 +427,6 @@ class CacheIntegrationTest(unittest.TestCase):
[
("eager", "static"),
("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
]
)
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
@@ -462,26 +460,6 @@ class CacheIntegrationTest(unittest.TestCase):
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model._forward = model.forward
compiled_forward = torch.compile(model.forward)
def compiled(func, input_ids, **kwargs):
return func(input_ids, **kwargs)
def call(input_ids, **kwargs):
if input_ids.shape[-1] == 1:
return compiled(compiled_forward, input_ids, **kwargs)
return model._forward(input_ids, **kwargs)
model.forward = call
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
def test_dynamic_cache_extra_left_padding(self):
"""Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
EXPECTED_GENERATION = [
@@ -519,7 +497,6 @@ class CacheIntegrationTest(unittest.TestCase):
@parameterized.expand(
[
"static",
"offloaded-static",
]
)
def test_static_cache_extra_left_padding(self, cache_implementation):