[Cache] Don't initialize the cache on meta device (#36543)
This commit is contained in:
@@ -10,7 +10,6 @@ from packaging import version
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .utils import is_hqq_available, is_optimum_quanto_available, logging
|
from .utils import is_hqq_available, is_optimum_quanto_available, logging
|
||||||
from .utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
|
|
||||||
if is_hqq_available():
|
if is_hqq_available():
|
||||||
@@ -1064,18 +1063,19 @@ class StaticCache(Cache):
|
|||||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||||
batch_size (`int`):
|
batch_size (`int`):
|
||||||
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
||||||
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
|
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
|
||||||
|
number of beams if you are running beam search
|
||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
The maximum sequence length with which the model will be used.
|
The maximum sequence length with which the model will be used.
|
||||||
device (`torch.device` or `str`):
|
device (`torch.device` or `str`):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
||||||
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
|
should pass the `layer_device_map` argument instead.
|
||||||
device by default, and then moved to input device when updating.
|
|
||||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
||||||
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
||||||
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
|
and the model is splitted between differents gpus. You can know which layers mapped to which device by
|
||||||
|
checking the associated device_map: `model.hf_device_map`.
|
||||||
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -1101,7 +1101,6 @@ class StaticCache(Cache):
|
|||||||
is_compileable = True
|
is_compileable = True
|
||||||
|
|
||||||
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
||||||
@deprecate_kwarg("layer_device_map", version="4.52.0")
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -1128,7 +1127,6 @@ class StaticCache(Cache):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = torch.device(device) if device is not None else torch.device("meta")
|
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_attention_heads
|
config.num_attention_heads
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
@@ -1139,11 +1137,12 @@ class StaticCache(Cache):
|
|||||||
self.value_cache: List[torch.Tensor] = []
|
self.value_cache: List[torch.Tensor] = []
|
||||||
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
|
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
|
||||||
cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||||
|
device = torch.device(device) if device is not None else None
|
||||||
for idx in range(config.num_hidden_layers):
|
for idx in range(config.num_hidden_layers):
|
||||||
if layer_device_map is not None:
|
if layer_device_map is not None:
|
||||||
layer_device = layer_device_map[idx]
|
layer_device = layer_device_map[idx]
|
||||||
else:
|
else:
|
||||||
layer_device = self.device
|
layer_device = device
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
||||||
@@ -1178,12 +1177,7 @@ class StaticCache(Cache):
|
|||||||
Return:
|
Return:
|
||||||
A tuple containing the updated key and value states.
|
A tuple containing the updated key and value states.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cache_position = cache_kwargs.get("cache_position")
|
cache_position = cache_kwargs.get("cache_position")
|
||||||
if self.key_cache[layer_idx].device.type == "meta":
|
|
||||||
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
|
|
||||||
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
|
|
||||||
|
|
||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
key_states = key_states.to(k_out.dtype)
|
key_states = key_states.to(k_out.dtype)
|
||||||
@@ -1211,8 +1205,6 @@ class StaticCache(Cache):
|
|||||||
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
||||||
# limit the check to the first batch member and head dimension.
|
# limit the check to the first batch member and head dimension.
|
||||||
# TODO: deprecate this function in favor of `cache_position`
|
# TODO: deprecate this function in favor of `cache_position`
|
||||||
if self.key_cache[layer_idx].device.type == "meta":
|
|
||||||
return 0
|
|
||||||
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
||||||
|
|
||||||
def get_max_cache_shape(self) -> Optional[int]:
|
def get_max_cache_shape(self) -> Optional[int]:
|
||||||
@@ -1221,10 +1213,9 @@ class StaticCache(Cache):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
"""Resets the cache values while preserving the objects"""
|
"""Resets the cache values while preserving the objects"""
|
||||||
for layer_idx in range(len(self.key_cache)):
|
for layer_idx in range(len(self.key_cache)):
|
||||||
if self.key_cache[layer_idx].device.type != "meta":
|
# In-place ops prevent breaking the static address
|
||||||
# In-place ops prevent breaking the static address
|
self.key_cache[layer_idx].zero_()
|
||||||
self.key_cache[layer_idx].zero_()
|
self.value_cache[layer_idx].zero_()
|
||||||
self.value_cache[layer_idx].zero_()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1261,14 +1252,14 @@ class SlidingWindowCache(StaticCache):
|
|||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
The maximum sequence length with which the model will be used.
|
The maximum sequence length with which the model will be used.
|
||||||
device (`torch.device` or `str`):
|
device (`torch.device` or `str`):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
||||||
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
|
should pass the `layer_device_map` argument instead.
|
||||||
device by default, and then moved to input device when updating.
|
|
||||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
||||||
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
||||||
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
|
and the model is splitted between differents gpus. You can know which layers mapped to which device by
|
||||||
|
checking the associated device_map: `model.hf_device_map`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -1329,11 +1320,6 @@ class SlidingWindowCache(StaticCache):
|
|||||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
cache_position = cache_kwargs.get("cache_position")
|
cache_position = cache_kwargs.get("cache_position")
|
||||||
|
|
||||||
if self.key_cache[layer_idx].device.type == "meta":
|
|
||||||
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
|
|
||||||
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
|
|
||||||
|
|
||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
key_states = key_states.to(k_out.dtype)
|
key_states = key_states.to(k_out.dtype)
|
||||||
@@ -1380,10 +1366,9 @@ class SlidingWindowCache(StaticCache):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
for layer_idx in range(len(self.key_cache)):
|
for layer_idx in range(len(self.key_cache)):
|
||||||
if self.key_cache[layer_idx].device.type != "meta":
|
# In-place ops prevent breaking the static address
|
||||||
# In-place ops prevent breaking the static address
|
self.key_cache[layer_idx].zero_()
|
||||||
self.key_cache[layer_idx].zero_()
|
self.value_cache[layer_idx].zero_()
|
||||||
self.value_cache[layer_idx].zero_()
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoderCache(Cache):
|
class EncoderDecoderCache(Cache):
|
||||||
@@ -1573,14 +1558,14 @@ class HybridCache(Cache):
|
|||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
The maximum sequence length with which the model will be used.
|
The maximum sequence length with which the model will be used.
|
||||||
device (`torch.device` or `str`, *optional*):
|
device (`torch.device` or `str`, *optional*):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
||||||
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
|
should pass the `layer_device_map` argument instead.
|
||||||
device by default, and then moved to input device when updating.
|
|
||||||
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
|
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
||||||
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
||||||
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
|
and the model is splitted between differents gpus. You can know which layers mapped to which device by
|
||||||
|
checking the associated device_map: `model.hf_device_map`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -1607,7 +1592,6 @@ class HybridCache(Cache):
|
|||||||
# is_compileable = True
|
# is_compileable = True
|
||||||
|
|
||||||
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
||||||
@deprecate_kwarg("layer_device_map", version="4.52.0")
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -1642,7 +1626,6 @@ class HybridCache(Cache):
|
|||||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
self.device = torch.device(device) if device is not None else torch.device("meta")
|
|
||||||
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
|
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
|
||||||
self.is_sliding = torch.tensor(
|
self.is_sliding = torch.tensor(
|
||||||
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
|
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
|
||||||
@@ -1656,11 +1639,12 @@ class HybridCache(Cache):
|
|||||||
min(config.sliding_window, max_cache_len),
|
min(config.sliding_window, max_cache_len),
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
)
|
)
|
||||||
|
device = torch.device(device) if device is not None else None
|
||||||
for i in range(config.num_hidden_layers):
|
for i in range(config.num_hidden_layers):
|
||||||
if layer_device_map is not None:
|
if layer_device_map is not None:
|
||||||
layer_device = layer_device_map[i]
|
layer_device = layer_device_map[i]
|
||||||
else:
|
else:
|
||||||
layer_device = self.device
|
layer_device = device
|
||||||
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
||||||
# breaks when updating the cache.
|
# breaks when updating the cache.
|
||||||
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
||||||
@@ -1717,9 +1701,12 @@ class HybridCache(Cache):
|
|||||||
cache_position = cache_kwargs.get("cache_position")
|
cache_position = cache_kwargs.get("cache_position")
|
||||||
sliding_window = cache_kwargs.get("sliding_window")
|
sliding_window = cache_kwargs.get("sliding_window")
|
||||||
|
|
||||||
if self.key_cache[layer_idx].device.type == "meta":
|
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
|
||||||
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
|
# when the cache is initialized in the forward pass (e.g. Gemma2)
|
||||||
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
|
if self.key_cache[layer_idx].device != key_states.device:
|
||||||
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
|
||||||
|
if self.value_cache[layer_idx].device != value_states.device:
|
||||||
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
|
||||||
|
|
||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
@@ -1753,18 +1740,14 @@ class HybridCache(Cache):
|
|||||||
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
|
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
|
||||||
"Using the `layer_idx` argument is not supported."
|
"Using the `layer_idx` argument is not supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.key_cache[layer_idx].device.type == "meta":
|
|
||||||
return 0
|
|
||||||
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Resets the cache values while preserving the objects"""
|
"""Resets the cache values while preserving the objects"""
|
||||||
for layer_idx in range(len(self.key_cache)):
|
for layer_idx in range(len(self.key_cache)):
|
||||||
if self.key_cache[layer_idx].device.type != "meta":
|
# In-place ops prevent breaking the static address
|
||||||
# In-place ops prevent breaking the static address
|
self.key_cache[layer_idx].zero_()
|
||||||
self.key_cache[layer_idx].zero_()
|
self.value_cache[layer_idx].zero_()
|
||||||
self.value_cache[layer_idx].zero_()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1789,24 +1772,6 @@ class MambaCache:
|
|||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
device (`torch.device` or `str`, *optional*):
|
device (`torch.device` or `str`, *optional*):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
|
|
||||||
device by default, and then moved to input device when updating.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
dtype: (`torch.dtype`):
|
|
||||||
The default `dtype` used to initializing the cache.
|
|
||||||
device (`torch.device`):
|
|
||||||
The default device on which the cache was initialized.
|
|
||||||
intermediate_size: (`int`):
|
|
||||||
Model's intermediate_size taken from config.
|
|
||||||
ssm_state_size: (`int`):
|
|
||||||
Model's state_size taken from config.
|
|
||||||
conv_kernel_size: (`int`):
|
|
||||||
Model's convolution kernel size taken from config
|
|
||||||
conv_states: (`torch.Tensor`):
|
|
||||||
A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
|
|
||||||
ssm_states: (`torch.Tensor`):
|
|
||||||
A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -1829,6 +1794,7 @@ class MambaCache:
|
|||||||
is_compileable = True
|
is_compileable = True
|
||||||
|
|
||||||
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
||||||
|
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -1847,23 +1813,23 @@ class MambaCache:
|
|||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.ssm_state_size = config.state_size
|
self.ssm_state_size = config.state_size
|
||||||
self.conv_kernel_size = config.conv_kernel
|
self.conv_kernel_size = config.conv_kernel
|
||||||
self.device = torch.device(device) if device is not None else torch.device("meta")
|
|
||||||
|
|
||||||
self.conv_states: List[torch.Tensor] = []
|
self.conv_states: List[torch.Tensor] = []
|
||||||
self.ssm_states: List[torch.Tensor] = []
|
self.ssm_states: List[torch.Tensor] = []
|
||||||
|
device = torch.device(device) if device is not None else None
|
||||||
for _ in range(config.num_hidden_layers):
|
for _ in range(config.num_hidden_layers):
|
||||||
conv_state: torch.Tensor = torch.zeros(
|
conv_state: torch.Tensor = torch.zeros(
|
||||||
self.max_batch_size,
|
self.max_batch_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.conv_kernel_size,
|
self.conv_kernel_size,
|
||||||
device=self.device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
ssm_state: torch.Tensor = torch.zeros(
|
ssm_state: torch.Tensor = torch.zeros(
|
||||||
self.max_batch_size,
|
self.max_batch_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.ssm_state_size,
|
self.ssm_state_size,
|
||||||
device=self.device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1875,11 +1841,10 @@ class MambaCache:
|
|||||||
def update_conv_state(
|
def update_conv_state(
|
||||||
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.conv_states[layer_idx].device.type == "meta":
|
# This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
|
||||||
self.conv_states[layer_idx] = torch.zeros_like(
|
# when the cache is initialized in the forward pass (e.g. Mamba)
|
||||||
self.conv_states[layer_idx],
|
if self.conv_states[layer_idx].device != new_conv_state.device:
|
||||||
device=new_conv_state.device,
|
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
|
||||||
)
|
|
||||||
|
|
||||||
conv_state = self.conv_states[layer_idx]
|
conv_state = self.conv_states[layer_idx]
|
||||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||||
@@ -1896,10 +1861,9 @@ class MambaCache:
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
for layer_idx in range(len(self.conv_states)):
|
for layer_idx in range(len(self.conv_states)):
|
||||||
if self.conv_states[layer_idx].device.type != "meta":
|
# In-place ops prevent breaking the static address
|
||||||
# In-place ops prevent breaking the static address
|
self.conv_states[layer_idx].zero_()
|
||||||
self.conv_states[layer_idx].zero_()
|
self.ssm_states[layer_idx].zero_()
|
||||||
self.ssm_states[layer_idx].zero_()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1924,33 +1888,16 @@ class OffloadedStaticCache(StaticCache):
|
|||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
The maximum sequence length with which the model will be used.
|
The maximum sequence length with which the model will be used.
|
||||||
device (`Union[str, torch.device]`):
|
device (`Union[str, torch.device]`):
|
||||||
The device on which the cache should be initialized. Should be the same as the
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
||||||
layer device.
|
should pass the `layer_device_map` argument instead.
|
||||||
dtype (`torch.dtype`, *optional*):
|
dtype (`torch.dtype`, *optional*):
|
||||||
The default `dtype` to use when initializing the cache.
|
The default `dtype` to use when initializing the cache.
|
||||||
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
|
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
|
||||||
The device to offload to. Defaults to CPU.
|
The device to offload to. Defaults to CPU.
|
||||||
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
|
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
|
||||||
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
||||||
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
|
and the model is splitted between differents gpus. You can know which layers mapped to which device by
|
||||||
|
checking the associated device_map: `model.hf_device_map`.
|
||||||
Attributes:
|
|
||||||
key_cache (`List[torch.Tensor]`):
|
|
||||||
Off-loaded key cache tensors. First one will be on device, where-as the others are
|
|
||||||
off-loaded.
|
|
||||||
value_cache (`List[torch.Tensor]`):
|
|
||||||
Off-loaded value cache tensors. First one will be on device, where-as the others are
|
|
||||||
off-loaded.
|
|
||||||
max_batch_size (`int`):
|
|
||||||
The maximum batch size with which this cache can be used.
|
|
||||||
max_cache_len (`int`):
|
|
||||||
The maximum sequence length with which this cache can be used.
|
|
||||||
device (`torch.device`):
|
|
||||||
The device on which the cache is used.
|
|
||||||
offload_device (`torch.device`):
|
|
||||||
The device used to offload to.
|
|
||||||
dtype (`torch.dtype`):
|
|
||||||
The `dtype` used to initializing the cache.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -1973,7 +1920,6 @@ class OffloadedStaticCache(StaticCache):
|
|||||||
|
|
||||||
is_compileable = True
|
is_compileable = True
|
||||||
|
|
||||||
@deprecate_kwarg("layer_device_map", version="4.52.0")
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
|||||||
@@ -483,7 +483,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
|
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
|
||||||
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
|
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
|
||||||
|
|
||||||
# Performances
|
# Performance
|
||||||
self.compile_config = kwargs.pop("compile_config", CompileConfig())
|
self.compile_config = kwargs.pop("compile_config", CompileConfig())
|
||||||
self.disable_compile = kwargs.pop("disable_compile", False)
|
self.disable_compile = kwargs.pop("disable_compile", False)
|
||||||
# Wild card
|
# Wild card
|
||||||
|
|||||||
@@ -1618,6 +1618,40 @@ class GenerationMixin:
|
|||||||
model_kwargs["cache_position"] = cache_position
|
model_kwargs["cache_position"] = cache_position
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
|
def _get_layer_device_map_for_cache_init(self):
|
||||||
|
"""
|
||||||
|
Taken from `dispatch_model` from accelerate.
|
||||||
|
This is needed here if we don't want to make changes in accelerate in order to save execution_device
|
||||||
|
For offloaded case, we need to get the execution device, not just the device where it is offloaded
|
||||||
|
"""
|
||||||
|
execution_device_map = None
|
||||||
|
|
||||||
|
if hasattr(self, "hf_device_map"):
|
||||||
|
if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
|
||||||
|
main_device = "cpu"
|
||||||
|
else:
|
||||||
|
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
|
||||||
|
execution_device_map = {
|
||||||
|
name: main_device if device in ["cpu", "disk"] else device
|
||||||
|
for name, device in self.hf_device_map.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
num_hidden_layers = self.config.get_text_config().num_hidden_layers
|
||||||
|
if execution_device_map is None:
|
||||||
|
return None
|
||||||
|
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
||||||
|
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
|
||||||
|
layer_device_map = {}
|
||||||
|
for layer in execution_device_map:
|
||||||
|
for idx in range(num_hidden_layers):
|
||||||
|
if f".{idx}." in f"{layer}.":
|
||||||
|
layer_device_map[idx] = execution_device_map[layer]
|
||||||
|
break
|
||||||
|
for idx in range(num_hidden_layers):
|
||||||
|
if idx not in layer_device_map:
|
||||||
|
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
||||||
|
return layer_device_map
|
||||||
|
|
||||||
def _get_cache(
|
def _get_cache(
|
||||||
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
|
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
|
||||||
) -> Cache:
|
) -> Cache:
|
||||||
@@ -1664,12 +1698,14 @@ class GenerationMixin:
|
|||||||
# models. May cause trobles with non-text modalities.
|
# models. May cause trobles with non-text modalities.
|
||||||
cache_dtype = self.get_output_embeddings().weight.dtype
|
cache_dtype = self.get_output_embeddings().weight.dtype
|
||||||
|
|
||||||
|
layer_device_map = self._get_layer_device_map_for_cache_init()
|
||||||
cache_kwargs = {
|
cache_kwargs = {
|
||||||
"config": self.config.get_text_config(),
|
"config": self.config.get_text_config(),
|
||||||
"max_batch_size": batch_size,
|
"max_batch_size": batch_size,
|
||||||
"max_cache_len": max_cache_len,
|
"max_cache_len": max_cache_len,
|
||||||
"dtype": cache_dtype,
|
"dtype": cache_dtype,
|
||||||
"device": device if cache_implementation == "offloaded_static" else None,
|
"device": device,
|
||||||
|
"layer_device_map": layer_device_map,
|
||||||
}
|
}
|
||||||
self._cache = cache_cls(**cache_kwargs)
|
self._cache = cache_cls(**cache_kwargs)
|
||||||
if requires_cross_attention_cache:
|
if requires_cross_attention_cache:
|
||||||
|
|||||||
@@ -597,11 +597,13 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
|||||||
|
|
||||||
if use_cache and past_key_values is None and not self.training:
|
if use_cache and past_key_values is None and not self.training:
|
||||||
batch_size, seq_len, _ = inputs_embeds.shape
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
|
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||||
past_key_values = HybridCache(
|
past_key_values = HybridCache(
|
||||||
self.config,
|
self.config,
|
||||||
max_batch_size=batch_size,
|
max_batch_size=batch_size,
|
||||||
max_cache_len=seq_len,
|
max_cache_len=seq_len,
|
||||||
dtype=inputs_embeds.dtype,
|
dtype=inputs_embeds.dtype,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
|
|||||||
@@ -488,11 +488,13 @@ class Cohere2Model(Gemma2Model):
|
|||||||
|
|
||||||
if use_cache and past_key_values is None and not self.training:
|
if use_cache and past_key_values is None and not self.training:
|
||||||
batch_size, seq_len, _ = inputs_embeds.shape
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
|
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||||
past_key_values = HybridCache(
|
past_key_values = HybridCache(
|
||||||
self.config,
|
self.config,
|
||||||
max_batch_size=batch_size,
|
max_batch_size=batch_size,
|
||||||
max_cache_len=seq_len,
|
max_cache_len=seq_len,
|
||||||
dtype=inputs_embeds.dtype,
|
dtype=inputs_embeds.dtype,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
|
|||||||
@@ -599,11 +599,13 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
|
|
||||||
if use_cache and past_key_values is None and not self.training:
|
if use_cache and past_key_values is None and not self.training:
|
||||||
batch_size, seq_len, _ = inputs_embeds.shape
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
|
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||||
past_key_values = HybridCache(
|
past_key_values = HybridCache(
|
||||||
self.config,
|
self.config,
|
||||||
max_batch_size=batch_size,
|
max_batch_size=batch_size,
|
||||||
max_cache_len=seq_len,
|
max_cache_len=seq_len,
|
||||||
dtype=inputs_embeds.dtype,
|
dtype=inputs_embeds.dtype,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
|
|||||||
@@ -437,11 +437,13 @@ class Gemma2Model(GemmaModel):
|
|||||||
|
|
||||||
if use_cache and past_key_values is None and not self.training:
|
if use_cache and past_key_values is None and not self.training:
|
||||||
batch_size, seq_len, _ = inputs_embeds.shape
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
|
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
|
||||||
past_key_values = HybridCache(
|
past_key_values = HybridCache(
|
||||||
self.config,
|
self.config,
|
||||||
max_batch_size=batch_size,
|
max_batch_size=batch_size,
|
||||||
max_cache_len=seq_len,
|
max_cache_len=seq_len,
|
||||||
dtype=inputs_embeds.dtype,
|
dtype=inputs_embeds.dtype,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
|
|||||||
@@ -2304,45 +2304,6 @@ class GenerationTesterMixin:
|
|||||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
||||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||||
|
|
||||||
@pytest.mark.generate
|
|
||||||
@is_flaky
|
|
||||||
def test_assisted_decoding_with_logits_to_keep(self):
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
|
||||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
|
||||||
if model_class._is_stateful:
|
|
||||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
|
||||||
|
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
|
||||||
if not hasattr(config.get_text_config(), "use_cache"):
|
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
assistant_model = model
|
|
||||||
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
|
|
||||||
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
|
|
||||||
# other methods will work as well)
|
|
||||||
generation_kwargs = {
|
|
||||||
"max_new_tokens": 10,
|
|
||||||
"do_sample": False,
|
|
||||||
"assistant_model": assistant_model,
|
|
||||||
"return_dict_in_generate": True,
|
|
||||||
"output_scores": True,
|
|
||||||
}
|
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
|
||||||
|
|
||||||
# Setting logits_to_keep at 0 keeps all logits (old behavior)
|
|
||||||
with_all_logits = model.generate(
|
|
||||||
**generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0
|
|
||||||
)
|
|
||||||
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
|
|
||||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs)
|
|
||||||
|
|
||||||
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
|
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_inherits_generation_mixin(self):
|
def test_inherits_generation_mixin(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from parameterized import parameterized
|
|||||||
|
|
||||||
from transformers import set_seed
|
from transformers import set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
CaptureStderr,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_gptq,
|
require_gptq,
|
||||||
@@ -654,3 +655,42 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
|
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_static_cache_no_cuda_graph_skips(self):
|
||||||
|
"""
|
||||||
|
Tests generating with static cache and compilation doesn't skip cuda graphs. Regression test for #36543.
|
||||||
|
|
||||||
|
(? We set `fullgraph=True`, which according to torch docs means it should raise an exception. Instead,
|
||||||
|
messages are being thrown to stderr?)
|
||||||
|
"""
|
||||||
|
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
||||||
|
inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
|
||||||
|
with CaptureStderr() as cap:
|
||||||
|
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
|
||||||
|
self.assertEqual(cap.err, "")
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_static_cache_multi_gpu(self):
|
||||||
|
"""Regression test for #35164: static cache with multi-gpu"""
|
||||||
|
|
||||||
|
model_id = "google/gemma-2-2b-it"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
|
||||||
|
num_hidden_layers = 26
|
||||||
|
for i in range(num_hidden_layers):
|
||||||
|
device_map[f"model.layers.{i}"] = 0 if i < 13 else 1
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
torch_dtype="bfloat16",
|
||||||
|
device_map=device_map,
|
||||||
|
)
|
||||||
|
inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
|
||||||
|
_ = model(**inputs)
|
||||||
|
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")
|
||||||
|
|||||||
Reference in New Issue
Block a user