Init cache on meta device (#35164)

* init cache on meta device

* offloaded static + enable tests

* tests weren't running before  :(

* update

* fix mamba

* fix copies

* update

* address comments and fix tests

* fix copies

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* mamba fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-01-22 09:49:17 +01:00
committed by GitHub
parent 870e2c8ea0
commit 373e50e970
10 changed files with 111 additions and 111 deletions

View File

@@ -1069,12 +1069,15 @@ class StaticCache(Cache):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device` or `str`): device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
Example: Example:
```python ```python
@@ -1096,6 +1099,7 @@ class StaticCache(Cache):
""" """
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
@@ -1122,6 +1126,7 @@ class StaticCache(Cache):
) )
self.dtype = dtype self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = ( self.num_key_value_heads = (
config.num_attention_heads config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None if getattr(config, "num_key_value_heads", None) is None
@@ -1136,7 +1141,7 @@ class StaticCache(Cache):
if layer_device_map is not None: if layer_device_map is not None:
layer_device = layer_device_map[idx] layer_device = layer_device_map[idx]
else: else:
layer_device = device layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes: # Notes:
@@ -1181,6 +1186,9 @@ class StaticCache(Cache):
""" """
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
@@ -1209,6 +1217,8 @@ class StaticCache(Cache):
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension. # limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position` # TODO: deprecate this function in favor of `cache_position`
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def get_max_cache_shape(self) -> Optional[int]: def get_max_cache_shape(self) -> Optional[int]:
@@ -1217,6 +1227,7 @@ class StaticCache(Cache):
def reset(self): def reset(self):
"""Resets the cache values while preserving the objects""" """Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)): for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address # In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
@@ -1257,6 +1268,8 @@ class SlidingWindowCache(StaticCache):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device` or `str`): device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
@@ -1321,8 +1334,15 @@ class SlidingWindowCache(StaticCache):
cache_kwargs: Optional[Dict[str, Any]] = None, cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] > self.max_cache_len: if cache_position.shape[0] > self.max_cache_len:
@@ -1365,6 +1385,7 @@ class SlidingWindowCache(StaticCache):
def reset(self): def reset(self):
for layer_idx in range(len(self.key_cache)): for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address # In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
@@ -1561,8 +1582,10 @@ class HybridCache(Cache):
smaller batch size is used. smaller batch size is used.
max_cache_len (`int`): max_cache_len (`int`):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*, defaults to `"cpu"`): device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
dtype (torch.dtype, *optional*, defaults to `torch.float32`): dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
@@ -1590,12 +1613,13 @@ class HybridCache(Cache):
""" """
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
batch_size: int = None, batch_size: int = None,
max_cache_len: int = None, max_cache_len: int = None,
device: Union[torch.device, str] = "cpu", device: Union[torch.device, str] = None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1623,9 +1647,11 @@ class HybridCache(Cache):
self.num_key_value_heads = ( self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
) )
self.device = torch.device(device) if device is not None else torch.device("meta")
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor( self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
) )
self.key_cache: List[torch.Tensor] = [] self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = []
@@ -1640,7 +1666,7 @@ class HybridCache(Cache):
if layer_device_map is not None: if layer_device_map is not None:
layer_device = layer_device_map[i] layer_device = layer_device_map[i]
else: else:
layer_device = device layer_device = self.device
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. # breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
@@ -1696,8 +1722,16 @@ class HybridCache(Cache):
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window") sliding_window = cache_kwargs.get("sliding_window")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if sliding_window: if sliding_window:
update_fn = self._sliding_update update_fn = self._sliding_update
else: else:
@@ -1725,11 +1759,15 @@ class HybridCache(Cache):
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported." "Using the `layer_idx` argument is not supported."
) )
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self): def reset(self):
"""Resets the cache values while preserving the objects""" """Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)): for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address # In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
@@ -1757,10 +1795,14 @@ class MambaCache:
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*): device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
Attributes: Attributes:
dtype: (`torch.dtype`): dtype: (`torch.dtype`):
The default `dtype` used to initializing the cache. The default `dtype` used to initializing the cache.
device (`torch.device`):
The default device on which the cache was initialized.
intermediate_size: (`int`): intermediate_size: (`int`):
Model's intermediate_size taken from config. Model's intermediate_size taken from config.
ssm_state_size: (`int`): ssm_state_size: (`int`):
@@ -1809,30 +1851,40 @@ class MambaCache:
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel self.conv_kernel_size = config.conv_kernel
self.device = torch.device(device) if device is not None else torch.device("meta")
self.conv_states: torch.Tensor = torch.zeros( self.conv_states: List[torch.Tensor] = []
config.num_hidden_layers, self.ssm_states: List[torch.Tensor] = []
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size, self.max_batch_size,
self.intermediate_size, self.intermediate_size,
self.conv_kernel_size, self.conv_kernel_size,
device=device, device=self.device,
dtype=dtype, dtype=dtype,
) )
self.ssm_states: torch.Tensor = torch.zeros( ssm_state: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size, self.max_batch_size,
self.intermediate_size, self.intermediate_size,
self.ssm_state_size, self.ssm_state_size,
device=device, device=self.device,
dtype=dtype, dtype=dtype,
) )
torch._dynamo.mark_static_address(self.conv_states) torch._dynamo.mark_static_address(conv_state)
torch._dynamo.mark_static_address(self.ssm_states) torch._dynamo.mark_static_address(ssm_state)
self.conv_states.append(conv_state)
self.ssm_states.append(ssm_state)
def update_conv_state( def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor: ) -> torch.Tensor:
if self.conv_states[layer_idx].device.type == "meta":
self.conv_states[layer_idx] = torch.zeros_like(
self.conv_states[layer_idx],
device=new_conv_state.device,
)
conv_state = self.conv_states[layer_idx] conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
@@ -1843,12 +1895,15 @@ class MambaCache:
return self.conv_states[layer_idx] return self.conv_states[layer_idx]
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
return self.ssm_states[layer_idx] return self.ssm_states[layer_idx]
def reset(self): def reset(self):
self.conv_states.zero_() for layer_idx in range(len(self.conv_states)):
self.ssm_states.zero_() if self.conv_states[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
@property @property
def batch_size(self): def batch_size(self):
@@ -1920,6 +1975,7 @@ class OffloadedStaticCache(StaticCache):
``` ```
""" """
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
@@ -1930,9 +1986,10 @@ class OffloadedStaticCache(StaticCache):
offload_device: Union[str, torch.device] = torch.device("cpu"), offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None: ) -> None:
super(Cache, self).__init__()
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else layer_device_map[0] self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
self.offload_device = torch.device(offload_device) self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32 self.dtype = dtype if dtype is not None else torch.float32

View File

@@ -1633,45 +1633,12 @@ class GenerationMixin:
# models. May cause trobles with non-text modalities. # models. May cause trobles with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype cache_dtype = self.get_output_embeddings().weight.dtype
def get_layer_device_map(execution_device_map: Optional[dict] = None):
num_hidden_layers = self.config.get_text_config().num_hidden_layers
if execution_device_map is None:
return None
elif len(execution_device_map) == 1 and "" in execution_device_map:
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
layer_device_map = {}
for layer in execution_device_map:
for idx in range(num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map
execution_device_map = None
# Taken from dispatch_model from accelerate.
# This is needed here if we don't want to make changes in accelerate in order to save execution_device
# For offloaded case, we need to get the execution device, not just the device where it is offloaded
if hasattr(self, "hf_device_map"):
if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
layer_device_map = get_layer_device_map(execution_device_map)
cache_kwargs = { cache_kwargs = {
"config": self.config.get_text_config(), "config": self.config.get_text_config(),
"max_batch_size": batch_size, "max_batch_size": batch_size,
"max_cache_len": max_cache_len, "max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype, "dtype": cache_dtype,
"layer_device_map": layer_device_map, "device": device if cache_implementation == "offloaded_static" else None,
} }
self._cache = cache_cls(**cache_kwargs) self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache: if requires_cross_attention_cache:

View File

@@ -73,6 +73,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
batch_size=self.model.generation_config.cache_config.batch_size, batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len, max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.dtype, dtype=self.model.dtype,
device=self.model.generation_config.cache_config.device,
) )
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal: if self.is_causal:

View File

@@ -582,7 +582,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
self.config, self.config,
max_batch_size=batch_size, max_batch_size=batch_size,
max_cache_len=seq_len, max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype, dtype=inputs_embeds.dtype,
) )

View File

@@ -461,7 +461,6 @@ class Cohere2Model(Gemma2Model):
self.config, self.config,
max_batch_size=batch_size, max_batch_size=batch_size,
max_cache_len=seq_len, max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype, dtype=inputs_embeds.dtype,
) )

View File

@@ -579,7 +579,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
self.config, self.config,
max_batch_size=batch_size, max_batch_size=batch_size,
max_cache_len=seq_len, max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype, dtype=inputs_embeds.dtype,
) )

View File

@@ -405,7 +405,6 @@ class Gemma2Model(GemmaModel):
self.config, self.config,
max_batch_size=batch_size, max_batch_size=batch_size,
max_cache_len=seq_len, max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype, dtype=inputs_embeds.dtype,
) )

View File

@@ -728,22 +728,13 @@ class LlamaIntegrationTest(unittest.TestCase):
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
# Static Cache # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used)
generated_ids = model.generate( generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
) )
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
# Static Cache + compile
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
@slow @slow
@require_read_token @require_read_token
def test_export_static_cache(self): def test_export_static_cache(self):
@@ -795,6 +786,7 @@ class LlamaIntegrationTest(unittest.TestCase):
cache_config={ cache_config={
"batch_size": batch_size, "batch_size": batch_size,
"max_cache_len": max_generation_length, "max_cache_len": max_generation_length,
"device": device,
}, },
), ),
) )

View File

@@ -4635,6 +4635,11 @@ class ModelTesterMixin:
fa2_correctly_converted = True fa2_correctly_converted = True
break break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else fa2_model.config._attn_implementation == "flash_attention_2"
)
self.assertTrue(fa2_correctly_converted) self.assertTrue(fa2_correctly_converted)
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
@@ -4653,6 +4658,11 @@ class ModelTesterMixin:
fa2_correctly_converted = True fa2_correctly_converted = True
break break
fa2_correctly_converted = (
fa2_correctly_converted
if not model_class._supports_flex_attn
else model_from_pretrained.config._attn_implementation == "flash_attention_2"
)
self.assertFalse(fa2_correctly_converted) self.assertFalse(fa2_correctly_converted)
def _get_custom_4d_mask_test_data(self): def _get_custom_4d_mask_test_data(self):

View File

@@ -198,6 +198,7 @@ class CacheTest(unittest.TestCase):
cache_config={ cache_config={
"batch_size": batch_size, "batch_size": batch_size,
"max_cache_len": max_cache_len, "max_cache_len": max_cache_len,
"device": device,
}, },
), ),
) )
@@ -310,11 +311,12 @@ class CacheIntegrationTest(unittest.TestCase):
do_sample=False, do_sample=False,
max_new_tokens=20, max_new_tokens=20,
num_return_sequences=2, num_return_sequences=2,
num_beams=2,
) )
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = [ expected_text = [
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", "Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", "Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",
] ]
self.assertListEqual(decoded, expected_text) self.assertListEqual(decoded, expected_text)
@@ -380,8 +382,6 @@ class CacheIntegrationTest(unittest.TestCase):
[ [
("eager", "static"), ("eager", "static"),
("sdpa", "static"), ("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
] ]
) )
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation): def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
@@ -427,8 +427,6 @@ class CacheIntegrationTest(unittest.TestCase):
[ [
("eager", "static"), ("eager", "static"),
("sdpa", "static"), ("sdpa", "static"),
("eager", "offloaded-static"),
("sdpa", "offloaded-static"),
] ]
) )
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation): def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
@@ -462,26 +460,6 @@ class CacheIntegrationTest(unittest.TestCase):
with self.subTest(f"{attn_implementation}, static, eager"): with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION) self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model._forward = model.forward
compiled_forward = torch.compile(model.forward)
def compiled(func, input_ids, **kwargs):
return func(input_ids, **kwargs)
def call(input_ids, **kwargs):
if input_ids.shape[-1] == 1:
return compiled(compiled_forward, input_ids, **kwargs)
return model._forward(input_ids, **kwargs)
model.forward = call
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
def test_dynamic_cache_extra_left_padding(self): def test_dynamic_cache_extra_left_padding(self):
"""Tests that adding extra left-padding does not affect the generation with the dynamic cache""" """Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
EXPECTED_GENERATION = [ EXPECTED_GENERATION = [
@@ -519,7 +497,6 @@ class CacheIntegrationTest(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
[ [
"static", "static",
"offloaded-static",
] ]
) )
def test_static_cache_extra_left_padding(self, cache_implementation): def test_static_cache_extra_left_padding(self, cache_implementation):