Cache: use batch_size instead of max_batch_size (#32657)
* more precise name
* better docstrings
* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
|
|||||||
|
|
||||||
past_key_values = StaticCache(
|
past_key_values = StaticCache(
|
||||||
config=model.config,
|
config=model.config,
|
||||||
max_batch_size=1,
|
batch_size=1,
|
||||||
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
|
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
|
||||||
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
||||||
device=model.device,
|
device=model.device,
|
||||||
@@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc
|
|||||||
batch_size, seq_length = inputs["input_ids"].shape
|
batch_size, seq_length = inputs["input_ids"].shape
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
past_key_values = StaticCache(
|
past_key_values = StaticCache(
|
||||||
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||||
)
|
)
|
||||||
cache_position = torch.arange(seq_length, device=torch_device)
|
cache_position = torch.arange(seq_length, device=torch_device)
|
||||||
generated_ids = torch.zeros(
|
generated_ids = torch.zeros(
|
||||||
|
|||||||
@@ -977,13 +977,14 @@ class StaticCache(Cache):
|
|||||||
Parameters:
|
Parameters:
|
||||||
config (`PretrainedConfig`):
|
config (`PretrainedConfig`):
|
||||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||||
max_batch_size (`int`):
|
batch_size (`int`):
|
||||||
The maximum batch size with which the model will be used.
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
||||||
|
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search.
|
||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
The maximum sequence length with which the model will be used.
|
The maximum sequence length with which the model will be used.
|
||||||
device (`torch.device`):
|
device (`torch.device` or `str`):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
dtype (*optional*, defaults to `torch.float32`):
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -999,22 +1000,37 @@ class StaticCache(Cache):
|
|||||||
>>> # Prepare a cache class and pass it to model's forward
|
>>> # Prepare a cache class and pass it to model's forward
|
||||||
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
||||||
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
||||||
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
>>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
||||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||||
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
batch_size: int = None,
|
||||||
|
max_cache_len: int = None,
|
||||||
|
device: torch.device = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
max_batch_size: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_batch_size = max_batch_size
|
if max_batch_size is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
|
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.batch_size = batch_size or max_batch_size
|
||||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||||
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||||
self.head_dim = (
|
self.head_dim = (
|
||||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dtype = dtype if dtype is not None else torch.float32
|
self.dtype = dtype
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_attention_heads
|
config.num_attention_heads
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
@@ -1024,7 +1040,7 @@ class StaticCache(Cache):
|
|||||||
self.key_cache: List[torch.Tensor] = []
|
self.key_cache: List[torch.Tensor] = []
|
||||||
self.value_cache: List[torch.Tensor] = []
|
self.value_cache: List[torch.Tensor] = []
|
||||||
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
|
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
|
||||||
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||||
for idx in range(config.num_hidden_layers):
|
for idx in range(config.num_hidden_layers):
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||||
@@ -1130,13 +1146,14 @@ class SlidingWindowCache(StaticCache):
|
|||||||
Parameters:
|
Parameters:
|
||||||
config (`PretrainedConfig`):
|
config (`PretrainedConfig`):
|
||||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||||
max_batch_size (`int`):
|
batch_size (`int`):
|
||||||
The maximum batch size with which the model will be used.
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
||||||
|
smaller batch size is used.
|
||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
The maximum sequence length with which the model will be used.
|
The maximum sequence length with which the model will be used.
|
||||||
device (`torch.device`):
|
device (`torch.device` or `str`):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
dtype (*optional*, defaults to `torch.float32`):
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -1152,13 +1169,22 @@ class SlidingWindowCache(StaticCache):
|
|||||||
>>> # Prepare a cache class and pass it to model's forward
|
>>> # Prepare a cache class and pass it to model's forward
|
||||||
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
||||||
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
||||||
>>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
>>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
||||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||||
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
batch_size: int = None,
|
||||||
|
max_cache_len: int = None,
|
||||||
|
device: torch.device = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
max_batch_size: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1168,7 +1194,12 @@ class SlidingWindowCache(StaticCache):
|
|||||||
)
|
)
|
||||||
max_cache_len = min(config.sliding_window, max_cache_len)
|
max_cache_len = min(config.sliding_window, max_cache_len)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
|
config=config,
|
||||||
|
batch_size=batch_size,
|
||||||
|
max_cache_len=max_cache_len,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
@@ -1407,13 +1438,14 @@ class HybridCache(Cache):
|
|||||||
Parameters:
|
Parameters:
|
||||||
config (`PretrainedConfig`):
|
config (`PretrainedConfig`):
|
||||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||||
max_batch_size (`int`):
|
batch_size (`int`):
|
||||||
The maximum batch size with which the model will be used.
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
||||||
|
smaller batch size is used.
|
||||||
max_cache_len (`int`):
|
max_cache_len (`int`):
|
||||||
The maximum sequence length with which the model will be used.
|
The maximum sequence length with which the model will be used.
|
||||||
device (`torch.device`, *optional*, defaults to `"cpu"`):
|
device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
dtype (*optional*, defaults to `torch.float32`):
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -1429,14 +1461,28 @@ class HybridCache(Cache):
|
|||||||
>>> # Prepare a cache class and pass it to model's forward
|
>>> # Prepare a cache class and pass it to model's forward
|
||||||
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
||||||
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
||||||
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
>>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
||||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||||
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
batch_size: int = None,
|
||||||
|
max_cache_len: int = None,
|
||||||
|
device: Union[torch.device, str] = "cpu",
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
max_batch_size: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if max_batch_size is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
|
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
||||||
|
)
|
||||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
||||||
@@ -1444,13 +1490,13 @@ class HybridCache(Cache):
|
|||||||
"config and it's not set to None."
|
"config and it's not set to None."
|
||||||
)
|
)
|
||||||
self.max_cache_len = max_cache_len
|
self.max_cache_len = max_cache_len
|
||||||
self.max_batch_size = max_batch_size
|
self.batch_size = batch_size or max_batch_size
|
||||||
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||||
self.head_dim = (
|
self.head_dim = (
|
||||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dtype = dtype if dtype is not None else torch.float32
|
self.dtype = dtype
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
@@ -1459,9 +1505,9 @@ class HybridCache(Cache):
|
|||||||
)
|
)
|
||||||
self.key_cache: List[torch.Tensor] = []
|
self.key_cache: List[torch.Tensor] = []
|
||||||
self.value_cache: List[torch.Tensor] = []
|
self.value_cache: List[torch.Tensor] = []
|
||||||
global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
|
global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
|
||||||
sliding_cache_shape = (
|
sliding_cache_shape = (
|
||||||
max_batch_size,
|
self.batch_size,
|
||||||
self.num_key_value_heads,
|
self.num_key_value_heads,
|
||||||
min(config.sliding_window, max_cache_len),
|
min(config.sliding_window, max_cache_len),
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -1564,11 +1610,12 @@ class MambaCache:
|
|||||||
Arguments:
|
Arguments:
|
||||||
config (`PretrainedConfig`):
|
config (`PretrainedConfig`):
|
||||||
The configuration file defining the shape-related attributes required to initialize the static cache.
|
The configuration file defining the shape-related attributes required to initialize the static cache.
|
||||||
max_batch_size (`int`):
|
batch_size (`int`):
|
||||||
The maximum batch size with which the model will be used.
|
The batch size with which the model will be used. Note that a new instance must be instantiated if a
|
||||||
dtype (*optional*, defaults to `torch.float16`):
|
smaller batch size is used.
|
||||||
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
device (`torch.device`, *optional*):
|
device (`torch.device` or `str`, *optional*):
|
||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
@@ -1596,29 +1643,35 @@ class MambaCache:
|
|||||||
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
|
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
|
||||||
|
|
||||||
>>> # Prepare a cache class and pass it to model's forward
|
>>> # Prepare a cache class and pass it to model's forward
|
||||||
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
|
>>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
|
||||||
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||||
>>> past_kv = outputs.past_key_values
|
>>> past_kv = outputs.past_key_values
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
max_batch_size: int,
|
batch_size: int = None,
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.float16,
|
||||||
device: Optional[str] = None,
|
device: Optional[Union[torch.device, str]] = None,
|
||||||
**kwargs,
|
max_batch_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
if max_batch_size is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
||||||
|
"v4.46. Use the more precisely named 'batch_size' argument instead."
|
||||||
|
)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.max_batch_size = max_batch_size
|
self.batch_size = batch_size or max_batch_size
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.ssm_state_size = config.state_size
|
self.ssm_state_size = config.state_size
|
||||||
self.conv_kernel_size = config.conv_kernel
|
self.conv_kernel_size = config.conv_kernel
|
||||||
|
|
||||||
self.conv_states: torch.Tensor = torch.zeros(
|
self.conv_states: torch.Tensor = torch.zeros(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
self.max_batch_size,
|
self.batch_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.conv_kernel_size,
|
self.conv_kernel_size,
|
||||||
device=device,
|
device=device,
|
||||||
@@ -1626,7 +1679,7 @@ class MambaCache:
|
|||||||
)
|
)
|
||||||
self.ssm_states: torch.Tensor = torch.zeros(
|
self.ssm_states: torch.Tensor = torch.zeros(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
self.max_batch_size,
|
self.batch_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.ssm_state_size,
|
self.ssm_state_size,
|
||||||
device=device,
|
device=device,
|
||||||
|
|||||||
@@ -1426,7 +1426,7 @@ class GenerationMixin:
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def _get_cache(
|
def _get_cache(
|
||||||
self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
|
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
|
||||||
) -> Cache:
|
) -> Cache:
|
||||||
"""
|
"""
|
||||||
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
||||||
@@ -1448,7 +1448,7 @@ class GenerationMixin:
|
|||||||
need_new_cache = (
|
need_new_cache = (
|
||||||
not hasattr(self, "_cache")
|
not hasattr(self, "_cache")
|
||||||
or (not isinstance(cache_to_check, cache_cls))
|
or (not isinstance(cache_to_check, cache_cls))
|
||||||
or cache_to_check.max_batch_size != max_batch_size
|
or cache_to_check.batch_size != batch_size
|
||||||
)
|
)
|
||||||
if cache_implementation != "mamba":
|
if cache_implementation != "mamba":
|
||||||
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
||||||
@@ -1473,7 +1473,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
cache_kwargs = {
|
cache_kwargs = {
|
||||||
"config": self.config,
|
"config": self.config,
|
||||||
"max_batch_size": max_batch_size,
|
"batch_size": batch_size,
|
||||||
"max_cache_len": max_cache_len,
|
"max_cache_len": max_cache_len,
|
||||||
"device": device,
|
"device": device,
|
||||||
"dtype": cache_dtype,
|
"dtype": cache_dtype,
|
||||||
@@ -1812,7 +1812,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
model_kwargs[cache_name] = self._get_cache(
|
model_kwargs[cache_name] = self._get_cache(
|
||||||
cache_implementation=generation_config.cache_implementation,
|
cache_implementation=generation_config.cache_implementation,
|
||||||
max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
|
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
|
||||||
max_cache_len=generation_config.max_length,
|
max_cache_len=generation_config.max_length,
|
||||||
device=device,
|
device=device,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
|
|||||||
@@ -818,7 +818,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
batch_size, seq_len, _ = inputs_embeds.shape
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
past_key_values = HybridCache(
|
past_key_values = HybridCache(
|
||||||
self.config,
|
self.config,
|
||||||
max_batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_cache_len=seq_len,
|
max_cache_len=seq_len,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=inputs_embeds.dtype,
|
dtype=inputs_embeds.dtype,
|
||||||
|
|||||||
@@ -1040,7 +1040,7 @@ class Mask4DTestHard(unittest.TestCase):
|
|||||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||||
past_key_values = StaticCache(
|
past_key_values = StaticCache(
|
||||||
config=self.model.config,
|
config=self.model.config,
|
||||||
max_batch_size=1,
|
batch_size=1,
|
||||||
max_cache_len=max_cache_len,
|
max_cache_len=max_cache_len,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
dtype=self.model.dtype,
|
dtype=self.model.dtype,
|
||||||
@@ -1088,7 +1088,7 @@ class Mask4DTestHard(unittest.TestCase):
|
|||||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||||
past_key_values = StaticCache(
|
past_key_values = StaticCache(
|
||||||
config=self.model.config,
|
config=self.model.config,
|
||||||
max_batch_size=1,
|
batch_size=1,
|
||||||
max_cache_len=max_cache_len,
|
max_cache_len=max_cache_len,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
dtype=self.model.dtype,
|
dtype=self.model.dtype,
|
||||||
|
|||||||
@@ -47,12 +47,12 @@ if is_torch_available():
|
|||||||
end_of_text_token = 32000
|
end_of_text_token = 32000
|
||||||
|
|
||||||
class Phi3MiniWithStaticCache(torch.nn.Module):
|
class Phi3MiniWithStaticCache(torch.nn.Module):
|
||||||
def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
|
def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.cache = StaticCache(
|
self.cache = StaticCache(
|
||||||
config=model.config,
|
config=model.config,
|
||||||
max_batch_size=max_batch_size,
|
batch_size=batch_size,
|
||||||
max_cache_len=max_seq_len,
|
max_cache_len=max_seq_len,
|
||||||
device=self.model.device,
|
device=self.model.device,
|
||||||
dtype=self.model.dtype,
|
dtype=self.model.dtype,
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ class AqlmTest(unittest.TestCase):
|
|||||||
# Setup static KV cache for generation
|
# Setup static KV cache for generation
|
||||||
past_key_values = StaticCache(
|
past_key_values = StaticCache(
|
||||||
config=self.quantized_model.config,
|
config=self.quantized_model.config,
|
||||||
max_batch_size=1,
|
batch_size=1,
|
||||||
max_cache_len=seq_length + self.max_new_tokens + 1,
|
max_cache_len=seq_length + self.max_new_tokens + 1,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
dtype=self.quantized_model.config._pre_quantization_dtype,
|
dtype=self.quantized_model.config._pre_quantization_dtype,
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ class CacheTest(unittest.TestCase):
|
|||||||
return random_keys, random_values
|
return random_keys, random_values
|
||||||
|
|
||||||
mha_config = LlamaConfig(num_attention_heads=32)
|
mha_config = LlamaConfig(num_attention_heads=32)
|
||||||
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device)
|
||||||
cached_keys, cached_values = mha_static_cache.update(
|
cached_keys, cached_values = mha_static_cache.update(
|
||||||
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||||
)
|
)
|
||||||
@@ -153,7 +153,7 @@ class CacheTest(unittest.TestCase):
|
|||||||
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
||||||
|
|
||||||
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
||||||
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device)
|
||||||
cached_keys, cached_values = gqa_static_cache.update(
|
cached_keys, cached_values = gqa_static_cache.update(
|
||||||
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||||
)
|
)
|
||||||
@@ -161,7 +161,7 @@ class CacheTest(unittest.TestCase):
|
|||||||
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
||||||
|
|
||||||
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
||||||
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device)
|
||||||
cached_keys, cached_values = mqa_static_cache.update(
|
cached_keys, cached_values = mqa_static_cache.update(
|
||||||
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||||
)
|
)
|
||||||
@@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase):
|
|||||||
|
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
max_batch_size = 1
|
batch_size = 1
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
"google/gemma-2b",
|
"google/gemma-2b",
|
||||||
@@ -203,7 +203,7 @@ class CacheTest(unittest.TestCase):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.model = model
|
self.model = model
|
||||||
self.static_cache = StaticCache(
|
self.static_cache = StaticCache(
|
||||||
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
|
config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
|
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
|
||||||
|
|||||||
@@ -75,6 +75,11 @@ OBJECTS_TO_IGNORE = [
|
|||||||
"TFSequenceSummary",
|
"TFSequenceSummary",
|
||||||
"TFBertTokenizer",
|
"TFBertTokenizer",
|
||||||
"TFGPT2Tokenizer",
|
"TFGPT2Tokenizer",
|
||||||
|
# Going through an argument deprecation cycle, remove after v4.46
|
||||||
|
"HybridCache",
|
||||||
|
"MambaCache",
|
||||||
|
"SlidingWindowCache",
|
||||||
|
"StaticCache",
|
||||||
# Missing arguments in the docstring
|
# Missing arguments in the docstring
|
||||||
"ASTFeatureExtractor",
|
"ASTFeatureExtractor",
|
||||||
"AlbertModel",
|
"AlbertModel",
|
||||||
|
|||||||
Reference in New Issue
Block a user