Remove deprecated batch_size parameter (#37007)

This commit is contained in:
cyyever
2025-03-27 23:01:56 +08:00
committed by GitHub
parent 4cc65e990f
commit 6cc9c8d7d1
4 changed files with 50 additions and 97 deletions

View File

@@ -1065,6 +1065,8 @@ class SinkCache(Cache):
""" """
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon. # with partially rotated position embeddings, like Phi or Persimmon.
if cache_kwargs is None:
cache_kwargs = {}
sin = cache_kwargs.get("sin") sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos") cos = cache_kwargs.get("cos")
partial_rotation_size = cache_kwargs.get("partial_rotation_size") partial_rotation_size = cache_kwargs.get("partial_rotation_size")
@@ -1140,20 +1142,20 @@ class StaticCache(Cache):
Parameters: Parameters:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`): max_batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
number of beams if you are running beam search number of beams if you are running beam search
max_cache_len (`int`): max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device` or `str`): device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead. should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`. checking the associated device_map: `model.hf_device_map`.
@@ -1170,7 +1172,7 @@ class StaticCache(Cache):
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation >>> outputs.past_key_values # access cache filled with key/values from generation
StaticCache() StaticCache()
@@ -1179,25 +1181,17 @@ class StaticCache(Cache):
is_compileable = True is_compileable = True
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
batch_size: Optional[int] = None, max_batch_size: int,
max_cache_len: Optional[int] = None, max_cache_len: Optional[int] = None,
device: torch.device = None, device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if batch_size is not None: self.max_batch_size = max_batch_size
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
@@ -1256,6 +1250,8 @@ class StaticCache(Cache):
Return: Return:
A tuple containing the updated key and value states. A tuple containing the updated key and value states.
""" """
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
@@ -1296,14 +1292,6 @@ class StaticCache(Cache):
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
class SlidingWindowCache(StaticCache): class SlidingWindowCache(StaticCache):
""" """
@@ -1325,19 +1313,19 @@ class SlidingWindowCache(StaticCache):
Parameters: Parameters:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`): max_batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. smaller batch size is used.
max_cache_len (`int`): max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device` or `str`): device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead. should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`. checking the associated device_map: `model.hf_device_map`.
Example: Example:
@@ -1353,7 +1341,7 @@ class SlidingWindowCache(StaticCache):
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation >>> outputs.past_key_values # access cache filled with key/values from generation
SlidingWindowCache() SlidingWindowCache()
@@ -1363,15 +1351,13 @@ class SlidingWindowCache(StaticCache):
is_sliding = True is_sliding = True
is_compileable = True is_compileable = True
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
batch_size: Optional[int] = None, max_batch_size: int,
max_cache_len: Optional[int] = None, max_cache_len: Optional[int] = None,
device: torch.device = None, device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None: ) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None: if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1383,11 +1369,10 @@ class SlidingWindowCache(StaticCache):
max_cache_len = min(config.sliding_window, max_cache_len) max_cache_len = min(config.sliding_window, max_cache_len)
super().__init__( super().__init__(
config=config, config=config,
batch_size=batch_size, max_batch_size=max_batch_size,
max_cache_len=max_cache_len, max_cache_len=max_cache_len,
device=device, device=device,
dtype=dtype, dtype=dtype,
max_batch_size=max_batch_size,
layer_device_map=layer_device_map, layer_device_map=layer_device_map,
) )
@@ -1397,7 +1382,9 @@ class SlidingWindowCache(StaticCache):
value_states: torch.Tensor, value_states: torch.Tensor,
layer_idx: int, layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None, cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
@@ -1631,19 +1618,19 @@ class HybridCache(Cache):
Parameters: Parameters:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`): max_batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. smaller batch size is used.
max_cache_len (`int`): max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used. The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*): device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead. should pass the `layer_device_map` argument instead.
dtype (torch.dtype, *optional*, defaults to `torch.float32`): dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`. checking the associated device_map: `model.hf_device_map`.
Example: Example:
@@ -1659,7 +1646,7 @@ class HybridCache(Cache):
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation >>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache() HybridCache()
@@ -1670,23 +1657,16 @@ class HybridCache(Cache):
# ALL changes from the PR that commented the line below when reactivating it. # ALL changes from the PR that commented the line below when reactivating it.
# is_compileable = True # is_compileable = True
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
batch_size: Optional[int] = None, max_batch_size: int,
max_cache_len: Optional[int] = None, max_cache_len: Optional[int] = None,
device: Union[torch.device, str] = None, device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if batch_size is not None:
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None: if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError( raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting " "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1694,7 +1674,7 @@ class HybridCache(Cache):
"config and it's not set to None." "config and it's not set to None."
) )
self.max_cache_len = max_cache_len self.max_cache_len = max_cache_len
self.max_batch_size = batch_size or max_batch_size self.max_batch_size = max_batch_size
# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = ( self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -1718,7 +1698,7 @@ class HybridCache(Cache):
min(config.sliding_window, max_cache_len), min(config.sliding_window, max_cache_len),
self.head_dim, self.head_dim,
) )
device = torch.device(device) if device is not None else None device = torch.device(device) if device is not None and isinstance(device, str) else None
for i in range(config.num_hidden_layers): for i in range(config.num_hidden_layers):
if layer_device_map is not None: if layer_device_map is not None:
layer_device = layer_device_map[i] layer_device = layer_device_map[i]
@@ -1776,7 +1756,9 @@ class HybridCache(Cache):
value_states: torch.Tensor, value_states: torch.Tensor,
layer_idx: int, layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None, cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window") sliding_window = cache_kwargs.get("sliding_window")
@@ -1828,14 +1810,6 @@ class HybridCache(Cache):
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
class MambaCache: class MambaCache:
""" """
@@ -1844,9 +1818,8 @@ class MambaCache:
Arguments: Arguments:
config (`PretrainedConfig`): config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache. The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`): max_batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer. The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*): device (`torch.device` or `str`, *optional*):
@@ -1863,7 +1836,7 @@ class MambaCache:
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward >>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype) >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values >>> outputs.past_key_values
MambaCache() MambaCache()
@@ -1872,23 +1845,16 @@ class MambaCache:
is_compileable = True is_compileable = True
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly # TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
batch_size: Optional[int] = None, max_batch_size: int,
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
device: Optional[Union[torch.device, str]] = None, device: Union[torch.device, str, None] = None,
max_batch_size: Optional[int] = None,
): ):
if batch_size is not None:
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype self.dtype = dtype
self.max_batch_size = batch_size or max_batch_size self.max_batch_size = max_batch_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel self.conv_kernel_size = config.conv_kernel
@@ -1944,14 +1910,6 @@ class MambaCache:
self.conv_states[layer_idx].zero_() self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_() self.ssm_states[layer_idx].zero_()
@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
class OffloadedStaticCache(StaticCache): class OffloadedStaticCache(StaticCache):
""" """

View File

@@ -422,7 +422,7 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
model.eval() model.eval()
# Create cache with float32 dtype # Create cache with float32 dtype
cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device) cache_params = MambaCache(config, max_batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
# If code is correct, no error occurs and test passes # If code is correct, no error occurs and test passes
outputs = model( outputs = model(

View File

@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
return random_keys, random_values return random_keys, random_values
mha_config = LlamaConfig(num_attention_heads=32) mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device) mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mha_static_cache.update( cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
) )
@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 32, 10, 128)) self.assertTrue(cached_values.shape == (1, 32, 10, 128))
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device) gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = gqa_static_cache.update( cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
) )
@@ -167,7 +167,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 4, 10, 128)) self.assertTrue(cached_values.shape == (1, 4, 10, 128))
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device) mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mqa_static_cache.update( cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
) )

View File

@@ -74,11 +74,6 @@ OBJECTS_TO_IGNORE = [
"TFSequenceSummary", "TFSequenceSummary",
"TFBertTokenizer", "TFBertTokenizer",
"TFGPT2Tokenizer", "TFGPT2Tokenizer",
# Going through an argument deprecation cycle, remove after v4.46
"HybridCache",
"MambaCache",
"SlidingWindowCache",
"StaticCache",
# Missing arguments in the docstring # Missing arguments in the docstring
"ASTFeatureExtractor", "ASTFeatureExtractor",
"AlbertModel", "AlbertModel",