Cache: use batch_size instead of max_batch_size (#32657)

* more precise name

* better docstrings

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-08-16 11:48:45 +01:00
committed by GitHub
parent 8f9fa3b081
commit cf32ee1753
9 changed files with 112 additions and 54 deletions

View File

@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
past_key_values = StaticCache( past_key_values = StaticCache(
config=model.config, config=model.config,
max_batch_size=1, batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases # If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2), max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device, device=model.device,
@@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc
batch_size, seq_length = inputs["input_ids"].shape batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad(): with torch.no_grad():
past_key_values = StaticCache( past_key_values = StaticCache(
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
) )
cache_position = torch.arange(seq_length, device=torch_device) cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros( generated_ids = torch.zeros(

View File

@@ -977,13 +977,14 @@ class StaticCache(Cache):
Parameters: Parameters:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`): batch_size (`int`):
The maximum batch size with which the model will be used. The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search.
max_cache_len (`int`): max_cache_len (`int`):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device`): device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`): dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
Example: Example:
@@ -999,22 +1000,37 @@ class StaticCache(Cache):
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
``` ```
""" """
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__() super().__init__()
self.max_batch_size = max_batch_size if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
self.batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = ( self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
) )
self.dtype = dtype if dtype is not None else torch.float32 self.dtype = dtype
self.num_key_value_heads = ( self.num_key_value_heads = (
config.num_attention_heads config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None if getattr(config, "num_key_value_heads", None) is None
@@ -1024,7 +1040,7 @@ class StaticCache(Cache):
self.key_cache: List[torch.Tensor] = [] self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = []
# Note: There will be significant perf decrease if switching to use 5D tensors instead. # Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers): for idx in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
@@ -1130,13 +1146,14 @@ class SlidingWindowCache(StaticCache):
Parameters: Parameters:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`): batch_size (`int`):
The maximum batch size with which the model will be used. The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`): max_cache_len (`int`):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device`): device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`): dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
Example: Example:
@@ -1152,13 +1169,22 @@ class SlidingWindowCache(StaticCache):
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
``` ```
""" """
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__() super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None: if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError( raise ValueError(
@@ -1168,7 +1194,12 @@ class SlidingWindowCache(StaticCache):
) )
max_cache_len = min(config.sliding_window, max_cache_len) max_cache_len = min(config.sliding_window, max_cache_len)
super().__init__( super().__init__(
config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype config=config,
batch_size=batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
max_batch_size=max_batch_size,
) )
def update( def update(
@@ -1407,13 +1438,14 @@ class HybridCache(Cache):
Parameters: Parameters:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`): batch_size (`int`):
The maximum batch size with which the model will be used. The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`): max_cache_len (`int`):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device`, *optional*, defaults to `"cpu"`): device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`): dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
Example: Example:
@@ -1429,14 +1461,28 @@ class HybridCache(Cache):
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
``` ```
""" """
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None: # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: Union[torch.device, str] = "cpu",
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__() super().__init__()
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None: if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError( raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting " "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1444,13 +1490,13 @@ class HybridCache(Cache):
"config and it's not set to None." "config and it's not set to None."
) )
self.max_cache_len = max_cache_len self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size self.batch_size = batch_size or max_batch_size
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = ( self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
) )
self.dtype = dtype if dtype is not None else torch.float32 self.dtype = dtype
self.num_key_value_heads = ( self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
) )
@@ -1459,9 +1505,9 @@ class HybridCache(Cache):
) )
self.key_cache: List[torch.Tensor] = [] self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = []
global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = ( sliding_cache_shape = (
max_batch_size, self.batch_size,
self.num_key_value_heads, self.num_key_value_heads,
min(config.sliding_window, max_cache_len), min(config.sliding_window, max_cache_len),
self.head_dim, self.head_dim,
@@ -1564,11 +1610,12 @@ class MambaCache:
Arguments: Arguments:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`): batch_size (`int`):
The maximum batch size with which the model will be used. The batch size with which the model will be used. Note that a new instance must be instantiated if a
dtype (*optional*, defaults to `torch.float16`): smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
device (`torch.device`, *optional*): device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer. The device on which the cache should be initialized. Should be the same as the layer.
Attributes: Attributes:
@@ -1596,29 +1643,35 @@ class MambaCache:
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv = outputs.past_key_values >>> past_kv = outputs.past_key_values
``` ```
""" """
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
max_batch_size: int, batch_size: int = None,
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
device: Optional[str] = None, device: Optional[Union[torch.device, str]] = None,
**kwargs, max_batch_size: Optional[int] = None,
): ):
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
self.dtype = dtype self.dtype = dtype
self.max_batch_size = max_batch_size self.batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel self.conv_kernel_size = config.conv_kernel
self.conv_states: torch.Tensor = torch.zeros( self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers, config.num_hidden_layers,
self.max_batch_size, self.batch_size,
self.intermediate_size, self.intermediate_size,
self.conv_kernel_size, self.conv_kernel_size,
device=device, device=device,
@@ -1626,7 +1679,7 @@ class MambaCache:
) )
self.ssm_states: torch.Tensor = torch.zeros( self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers, config.num_hidden_layers,
self.max_batch_size, self.batch_size,
self.intermediate_size, self.intermediate_size,
self.ssm_state_size, self.ssm_state_size,
device=device, device=device,

View File

@@ -1426,7 +1426,7 @@ class GenerationMixin:
return model_kwargs return model_kwargs
def _get_cache( def _get_cache(
self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache: ) -> Cache:
""" """
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
@@ -1448,7 +1448,7 @@ class GenerationMixin:
need_new_cache = ( need_new_cache = (
not hasattr(self, "_cache") not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls)) or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != max_batch_size or cache_to_check.batch_size != batch_size
) )
if cache_implementation != "mamba": if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1473,7 +1473,7 @@ class GenerationMixin:
cache_kwargs = { cache_kwargs = {
"config": self.config, "config": self.config,
"max_batch_size": max_batch_size, "batch_size": batch_size,
"max_cache_len": max_cache_len, "max_cache_len": max_cache_len,
"device": device, "device": device,
"dtype": cache_dtype, "dtype": cache_dtype,
@@ -1812,7 +1812,7 @@ class GenerationMixin:
) )
model_kwargs[cache_name] = self._get_cache( model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation, cache_implementation=generation_config.cache_implementation,
max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
max_cache_len=generation_config.max_length, max_cache_len=generation_config.max_length,
device=device, device=device,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,

View File

@@ -818,7 +818,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
batch_size, seq_len, _ = inputs_embeds.shape batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache( past_key_values = HybridCache(
self.config, self.config,
max_batch_size=batch_size, batch_size=batch_size,
max_cache_len=seq_len, max_cache_len=seq_len,
device=self.device, device=self.device,
dtype=inputs_embeds.dtype, dtype=inputs_embeds.dtype,

View File

@@ -1040,7 +1040,7 @@ class Mask4DTestHard(unittest.TestCase):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache( past_key_values = StaticCache(
config=self.model.config, config=self.model.config,
max_batch_size=1, batch_size=1,
max_cache_len=max_cache_len, max_cache_len=max_cache_len,
device=torch_device, device=torch_device,
dtype=self.model.dtype, dtype=self.model.dtype,
@@ -1088,7 +1088,7 @@ class Mask4DTestHard(unittest.TestCase):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache( past_key_values = StaticCache(
config=self.model.config, config=self.model.config,
max_batch_size=1, batch_size=1,
max_cache_len=max_cache_len, max_cache_len=max_cache_len,
device=torch_device, device=torch_device,
dtype=self.model.dtype, dtype=self.model.dtype,

View File

@@ -47,12 +47,12 @@ if is_torch_available():
end_of_text_token = 32000 end_of_text_token = 32000
class Phi3MiniWithStaticCache(torch.nn.Module): class Phi3MiniWithStaticCache(torch.nn.Module):
def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int): def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
super().__init__() super().__init__()
self.model = model self.model = model
self.cache = StaticCache( self.cache = StaticCache(
config=model.config, config=model.config,
max_batch_size=max_batch_size, batch_size=batch_size,
max_cache_len=max_seq_len, max_cache_len=max_seq_len,
device=self.model.device, device=self.model.device,
dtype=self.model.dtype, dtype=self.model.dtype,

View File

@@ -216,7 +216,7 @@ class AqlmTest(unittest.TestCase):
# Setup static KV cache for generation # Setup static KV cache for generation
past_key_values = StaticCache( past_key_values = StaticCache(
config=self.quantized_model.config, config=self.quantized_model.config,
max_batch_size=1, batch_size=1,
max_cache_len=seq_length + self.max_new_tokens + 1, max_cache_len=seq_length + self.max_new_tokens + 1,
device=torch_device, device=torch_device,
dtype=self.quantized_model.config._pre_quantization_dtype, dtype=self.quantized_model.config._pre_quantization_dtype,

View File

@@ -145,7 +145,7 @@ class CacheTest(unittest.TestCase):
return random_keys, random_values return random_keys, random_values
mha_config = LlamaConfig(num_attention_heads=32) mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mha_static_cache.update( cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
) )
@@ -153,7 +153,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 32, 10, 128)) self.assertTrue(cached_values.shape == (1, 32, 10, 128))
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = gqa_static_cache.update( cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
) )
@@ -161,7 +161,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 4, 10, 128)) self.assertTrue(cached_values.shape == (1, 4, 10, 128))
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mqa_static_cache.update( cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
) )
@@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase):
device = "cpu" device = "cpu"
dtype = torch.float32 dtype = torch.float32
max_batch_size = 1 batch_size = 1
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
"google/gemma-2b", "google/gemma-2b",
@@ -203,7 +203,7 @@ class CacheTest(unittest.TestCase):
self.config = config self.config = config
self.model = model self.model = model
self.static_cache = StaticCache( self.static_cache = StaticCache(
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
) )
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor): def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):

View File

@@ -75,6 +75,11 @@ OBJECTS_TO_IGNORE = [
"TFSequenceSummary", "TFSequenceSummary",
"TFBertTokenizer", "TFBertTokenizer",
"TFGPT2Tokenizer", "TFGPT2Tokenizer",
# Going through an argument deprecation cycle, remove after v4.46
"HybridCache",
"MambaCache",
"SlidingWindowCache",
"StaticCache",
# Missing arguments in the docstring # Missing arguments in the docstring
"ASTFeatureExtractor", "ASTFeatureExtractor",
"AlbertModel", "AlbertModel",