|
|
|
@@ -1065,6 +1065,8 @@ class SinkCache(Cache):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
|
|
|
|
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
|
|
|
|
# with partially rotated position embeddings, like Phi or Persimmon.
|
|
|
|
# with partially rotated position embeddings, like Phi or Persimmon.
|
|
|
|
|
|
|
|
if cache_kwargs is None:
|
|
|
|
|
|
|
|
cache_kwargs = {}
|
|
|
|
sin = cache_kwargs.get("sin")
|
|
|
|
sin = cache_kwargs.get("sin")
|
|
|
|
cos = cache_kwargs.get("cos")
|
|
|
|
cos = cache_kwargs.get("cos")
|
|
|
|
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
|
|
|
|
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
|
|
|
|
@@ -1140,20 +1142,20 @@ class StaticCache(Cache):
|
|
|
|
Parameters:
|
|
|
|
Parameters:
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
batch_size (`int`):
|
|
|
|
max_batch_size (`int`):
|
|
|
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
|
|
|
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
|
|
|
|
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
|
|
|
|
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
|
|
|
|
number of beams if you are running beam search
|
|
|
|
number of beams if you are running beam search
|
|
|
|
max_cache_len (`int`):
|
|
|
|
max_cache_len (`int`, *optional*):
|
|
|
|
The maximum sequence length with which the model will be used.
|
|
|
|
The maximum sequence length with which the model will be used.
|
|
|
|
device (`torch.device` or `str`):
|
|
|
|
device (`torch.device` or `str`, *optional*):
|
|
|
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
|
|
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
|
|
|
should pass the `layer_device_map` argument instead.
|
|
|
|
should pass the `layer_device_map` argument instead.
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
|
|
|
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
|
|
|
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
|
|
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
|
|
|
and the model is splitted between differents gpus. You can know which layers mapped to which device by
|
|
|
|
and the model is split between different gpus. You can know which layers mapped to which device by
|
|
|
|
checking the associated device_map: `model.hf_device_map`.
|
|
|
|
checking the associated device_map: `model.hf_device_map`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1170,7 +1172,7 @@ class StaticCache(Cache):
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
>>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
StaticCache()
|
|
|
|
StaticCache()
|
|
|
|
@@ -1179,25 +1181,17 @@ class StaticCache(Cache):
|
|
|
|
|
|
|
|
|
|
|
|
is_compileable = True
|
|
|
|
is_compileable = True
|
|
|
|
|
|
|
|
|
|
|
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
batch_size: Optional[int] = None,
|
|
|
|
max_batch_size: int,
|
|
|
|
max_cache_len: Optional[int] = None,
|
|
|
|
max_cache_len: Optional[int] = None,
|
|
|
|
device: torch.device = None,
|
|
|
|
device: Union[torch.device, str, None] = None,
|
|
|
|
dtype: torch.dtype = torch.float32,
|
|
|
|
dtype: torch.dtype = torch.float32,
|
|
|
|
max_batch_size: Optional[int] = None,
|
|
|
|
|
|
|
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
|
|
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
|
|
|
) -> None:
|
|
|
|
) -> None:
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
if batch_size is not None:
|
|
|
|
self.max_batch_size = max_batch_size
|
|
|
|
logger.warning_once(
|
|
|
|
|
|
|
|
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
|
|
|
|
|
|
|
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.max_batch_size = batch_size or max_batch_size
|
|
|
|
|
|
|
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
|
|
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
|
|
|
|
|
|
|
|
|
|
|
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
|
|
|
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
|
|
|
@@ -1256,6 +1250,8 @@ class StaticCache(Cache):
|
|
|
|
Return:
|
|
|
|
Return:
|
|
|
|
A tuple containing the updated key and value states.
|
|
|
|
A tuple containing the updated key and value states.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
if cache_kwargs is None:
|
|
|
|
|
|
|
|
cache_kwargs = {}
|
|
|
|
cache_position = cache_kwargs.get("cache_position")
|
|
|
|
cache_position = cache_kwargs.get("cache_position")
|
|
|
|
k_out = self.key_cache[layer_idx]
|
|
|
|
k_out = self.key_cache[layer_idx]
|
|
|
|
v_out = self.value_cache[layer_idx]
|
|
|
|
v_out = self.value_cache[layer_idx]
|
|
|
|
@@ -1296,14 +1292,6 @@ class StaticCache(Cache):
|
|
|
|
self.key_cache[layer_idx].zero_()
|
|
|
|
self.key_cache[layer_idx].zero_()
|
|
|
|
self.value_cache[layer_idx].zero_()
|
|
|
|
self.value_cache[layer_idx].zero_()
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def batch_size(self):
|
|
|
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
|
|
|
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
|
|
|
|
|
|
|
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return self.max_batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SlidingWindowCache(StaticCache):
|
|
|
|
class SlidingWindowCache(StaticCache):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
@@ -1325,19 +1313,19 @@ class SlidingWindowCache(StaticCache):
|
|
|
|
Parameters:
|
|
|
|
Parameters:
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
batch_size (`int`):
|
|
|
|
max_batch_size (`int`):
|
|
|
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
|
|
|
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
|
|
|
|
smaller batch size is used.
|
|
|
|
smaller batch size is used.
|
|
|
|
max_cache_len (`int`):
|
|
|
|
max_cache_len (`int`, *optional*):
|
|
|
|
The maximum sequence length with which the model will be used.
|
|
|
|
The maximum sequence length with which the model will be used.
|
|
|
|
device (`torch.device` or `str`):
|
|
|
|
device (`torch.device` or `str`, *optional*):
|
|
|
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
|
|
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
|
|
|
should pass the `layer_device_map` argument instead.
|
|
|
|
should pass the `layer_device_map` argument instead.
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
|
|
|
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
|
|
|
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
|
|
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
|
|
|
and the model is splitted between differents gpus. You can know which layers mapped to which device by
|
|
|
|
and the model is split between different gpus. You can know which layers mapped to which device by
|
|
|
|
checking the associated device_map: `model.hf_device_map`.
|
|
|
|
checking the associated device_map: `model.hf_device_map`.
|
|
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
Example:
|
|
|
|
@@ -1353,7 +1341,7 @@ class SlidingWindowCache(StaticCache):
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
>>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
>>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
SlidingWindowCache()
|
|
|
|
SlidingWindowCache()
|
|
|
|
@@ -1363,15 +1351,13 @@ class SlidingWindowCache(StaticCache):
|
|
|
|
is_sliding = True
|
|
|
|
is_sliding = True
|
|
|
|
is_compileable = True
|
|
|
|
is_compileable = True
|
|
|
|
|
|
|
|
|
|
|
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
batch_size: Optional[int] = None,
|
|
|
|
max_batch_size: int,
|
|
|
|
max_cache_len: Optional[int] = None,
|
|
|
|
max_cache_len: Optional[int] = None,
|
|
|
|
device: torch.device = None,
|
|
|
|
device: Union[torch.device, str, None] = None,
|
|
|
|
dtype: torch.dtype = torch.float32,
|
|
|
|
dtype: torch.dtype = torch.float32,
|
|
|
|
max_batch_size: Optional[int] = None,
|
|
|
|
|
|
|
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
|
|
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
|
|
|
) -> None:
|
|
|
|
) -> None:
|
|
|
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
|
|
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
|
|
|
@@ -1383,11 +1369,10 @@ class SlidingWindowCache(StaticCache):
|
|
|
|
max_cache_len = min(config.sliding_window, max_cache_len)
|
|
|
|
max_cache_len = min(config.sliding_window, max_cache_len)
|
|
|
|
super().__init__(
|
|
|
|
super().__init__(
|
|
|
|
config=config,
|
|
|
|
config=config,
|
|
|
|
batch_size=batch_size,
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
max_cache_len=max_cache_len,
|
|
|
|
max_cache_len=max_cache_len,
|
|
|
|
device=device,
|
|
|
|
device=device,
|
|
|
|
dtype=dtype,
|
|
|
|
dtype=dtype,
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
|
|
|
|
layer_device_map=layer_device_map,
|
|
|
|
layer_device_map=layer_device_map,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1397,7 +1382,9 @@ class SlidingWindowCache(StaticCache):
|
|
|
|
value_states: torch.Tensor,
|
|
|
|
value_states: torch.Tensor,
|
|
|
|
layer_idx: int,
|
|
|
|
layer_idx: int,
|
|
|
|
cache_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
|
cache_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
|
) -> Tuple[torch.Tensor]:
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
|
if cache_kwargs is None:
|
|
|
|
|
|
|
|
cache_kwargs = {}
|
|
|
|
cache_position = cache_kwargs.get("cache_position")
|
|
|
|
cache_position = cache_kwargs.get("cache_position")
|
|
|
|
k_out = self.key_cache[layer_idx]
|
|
|
|
k_out = self.key_cache[layer_idx]
|
|
|
|
v_out = self.value_cache[layer_idx]
|
|
|
|
v_out = self.value_cache[layer_idx]
|
|
|
|
@@ -1631,19 +1618,19 @@ class HybridCache(Cache):
|
|
|
|
Parameters:
|
|
|
|
Parameters:
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
batch_size (`int`):
|
|
|
|
max_batch_size (`int`):
|
|
|
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
|
|
|
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
|
|
|
|
smaller batch size is used.
|
|
|
|
smaller batch size is used.
|
|
|
|
max_cache_len (`int`):
|
|
|
|
max_cache_len (`int`, *optional*):
|
|
|
|
The maximum sequence length with which the model will be used.
|
|
|
|
The maximum sequence length with which the model will be used.
|
|
|
|
device (`torch.device` or `str`, *optional*):
|
|
|
|
device (`torch.device` or `str`, *optional*):
|
|
|
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
|
|
|
The device on which the cache should be initialized. If you're using more than 1 computation device, you
|
|
|
|
should pass the `layer_device_map` argument instead.
|
|
|
|
should pass the `layer_device_map` argument instead.
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
|
|
|
|
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
|
|
|
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
|
|
|
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
|
|
|
and the model is splitted between differents gpus. You can know which layers mapped to which device by
|
|
|
|
and the model is split between different gpus. You can know which layers mapped to which device by
|
|
|
|
checking the associated device_map: `model.hf_device_map`.
|
|
|
|
checking the associated device_map: `model.hf_device_map`.
|
|
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
Example:
|
|
|
|
@@ -1659,7 +1646,7 @@ class HybridCache(Cache):
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
|
|
>>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
|
|
HybridCache()
|
|
|
|
HybridCache()
|
|
|
|
@@ -1670,23 +1657,16 @@ class HybridCache(Cache):
|
|
|
|
# ALL changes from the PR that commented the line below when reactivating it.
|
|
|
|
# ALL changes from the PR that commented the line below when reactivating it.
|
|
|
|
# is_compileable = True
|
|
|
|
# is_compileable = True
|
|
|
|
|
|
|
|
|
|
|
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
batch_size: Optional[int] = None,
|
|
|
|
max_batch_size: int,
|
|
|
|
max_cache_len: Optional[int] = None,
|
|
|
|
max_cache_len: Optional[int] = None,
|
|
|
|
device: Union[torch.device, str] = None,
|
|
|
|
device: Union[torch.device, str, None] = None,
|
|
|
|
dtype: torch.dtype = torch.float32,
|
|
|
|
dtype: torch.dtype = torch.float32,
|
|
|
|
max_batch_size: Optional[int] = None,
|
|
|
|
|
|
|
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
|
|
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
|
|
|
) -> None:
|
|
|
|
) -> None:
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
if batch_size is not None:
|
|
|
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
|
|
|
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
|
|
|
|
|
|
|
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
|
|
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
|
|
|
raise ValueError(
|
|
|
|
raise ValueError(
|
|
|
|
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
|
|
|
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
|
|
|
@@ -1694,7 +1674,7 @@ class HybridCache(Cache):
|
|
|
|
"config and it's not set to None."
|
|
|
|
"config and it's not set to None."
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.max_cache_len = max_cache_len
|
|
|
|
self.max_cache_len = max_cache_len
|
|
|
|
self.max_batch_size = batch_size or max_batch_size
|
|
|
|
self.max_batch_size = max_batch_size
|
|
|
|
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
|
|
|
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
|
|
|
self.head_dim = (
|
|
|
|
self.head_dim = (
|
|
|
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
|
|
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
|
|
|
@@ -1718,7 +1698,7 @@ class HybridCache(Cache):
|
|
|
|
min(config.sliding_window, max_cache_len),
|
|
|
|
min(config.sliding_window, max_cache_len),
|
|
|
|
self.head_dim,
|
|
|
|
self.head_dim,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
device = torch.device(device) if device is not None else None
|
|
|
|
device = torch.device(device) if device is not None and isinstance(device, str) else None
|
|
|
|
for i in range(config.num_hidden_layers):
|
|
|
|
for i in range(config.num_hidden_layers):
|
|
|
|
if layer_device_map is not None:
|
|
|
|
if layer_device_map is not None:
|
|
|
|
layer_device = layer_device_map[i]
|
|
|
|
layer_device = layer_device_map[i]
|
|
|
|
@@ -1776,7 +1756,9 @@ class HybridCache(Cache):
|
|
|
|
value_states: torch.Tensor,
|
|
|
|
value_states: torch.Tensor,
|
|
|
|
layer_idx: int,
|
|
|
|
layer_idx: int,
|
|
|
|
cache_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
|
cache_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
|
) -> Tuple[torch.Tensor]:
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
|
if cache_kwargs is None:
|
|
|
|
|
|
|
|
cache_kwargs = {}
|
|
|
|
cache_position = cache_kwargs.get("cache_position")
|
|
|
|
cache_position = cache_kwargs.get("cache_position")
|
|
|
|
sliding_window = cache_kwargs.get("sliding_window")
|
|
|
|
sliding_window = cache_kwargs.get("sliding_window")
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1828,14 +1810,6 @@ class HybridCache(Cache):
|
|
|
|
self.key_cache[layer_idx].zero_()
|
|
|
|
self.key_cache[layer_idx].zero_()
|
|
|
|
self.value_cache[layer_idx].zero_()
|
|
|
|
self.value_cache[layer_idx].zero_()
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def batch_size(self):
|
|
|
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
|
|
|
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
|
|
|
|
|
|
|
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return self.max_batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MambaCache:
|
|
|
|
class MambaCache:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
@@ -1844,9 +1818,8 @@ class MambaCache:
|
|
|
|
Arguments:
|
|
|
|
Arguments:
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
config (`PretrainedConfig`):
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
|
|
|
batch_size (`int`):
|
|
|
|
max_batch_size (`int`):
|
|
|
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
|
|
|
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
|
|
|
|
smaller batch size is used.
|
|
|
|
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
|
|
|
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
The default `dtype` to use when initializing the layer.
|
|
|
|
device (`torch.device` or `str`, *optional*):
|
|
|
|
device (`torch.device` or `str`, *optional*):
|
|
|
|
@@ -1863,7 +1836,7 @@ class MambaCache:
|
|
|
|
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
|
|
|
|
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
|
|
>>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
|
|
|
|
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
|
|
>>> outputs.past_key_values
|
|
|
|
>>> outputs.past_key_values
|
|
|
|
MambaCache()
|
|
|
|
MambaCache()
|
|
|
|
@@ -1872,23 +1845,16 @@ class MambaCache:
|
|
|
|
|
|
|
|
|
|
|
|
is_compileable = True
|
|
|
|
is_compileable = True
|
|
|
|
|
|
|
|
|
|
|
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
|
|
|
|
|
|
|
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
|
|
|
|
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
config: PretrainedConfig,
|
|
|
|
batch_size: Optional[int] = None,
|
|
|
|
max_batch_size: int,
|
|
|
|
dtype: torch.dtype = torch.float16,
|
|
|
|
dtype: torch.dtype = torch.float16,
|
|
|
|
device: Optional[Union[torch.device, str]] = None,
|
|
|
|
device: Union[torch.device, str, None] = None,
|
|
|
|
max_batch_size: Optional[int] = None,
|
|
|
|
|
|
|
|
):
|
|
|
|
):
|
|
|
|
if batch_size is not None:
|
|
|
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
|
|
|
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
|
|
|
|
|
|
|
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self.dtype = dtype
|
|
|
|
self.dtype = dtype
|
|
|
|
self.max_batch_size = batch_size or max_batch_size
|
|
|
|
self.max_batch_size = max_batch_size
|
|
|
|
self.intermediate_size = config.intermediate_size
|
|
|
|
self.intermediate_size = config.intermediate_size
|
|
|
|
self.ssm_state_size = config.state_size
|
|
|
|
self.ssm_state_size = config.state_size
|
|
|
|
self.conv_kernel_size = config.conv_kernel
|
|
|
|
self.conv_kernel_size = config.conv_kernel
|
|
|
|
@@ -1944,14 +1910,6 @@ class MambaCache:
|
|
|
|
self.conv_states[layer_idx].zero_()
|
|
|
|
self.conv_states[layer_idx].zero_()
|
|
|
|
self.ssm_states[layer_idx].zero_()
|
|
|
|
self.ssm_states[layer_idx].zero_()
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def batch_size(self):
|
|
|
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
|
|
|
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
|
|
|
|
|
|
|
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return self.max_batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OffloadedStaticCache(StaticCache):
|
|
|
|
class OffloadedStaticCache(StaticCache):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|