Cache: use batch_size instead of max_batch_size (#32657)

* more precise name

* better docstrings

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-08-16 11:48:45 +01:00
committed by GitHub
parent 8f9fa3b081
commit cf32ee1753
9 changed files with 112 additions and 54 deletions

View File

@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
@@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(

View File

@@ -977,13 +977,14 @@ class StaticCache(Cache):
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
Example:
@@ -999,22 +1000,37 @@ class StaticCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
self.batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1024,7 +1040,7 @@ class StaticCache(Cache):
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
@@ -1130,13 +1146,14 @@ class SlidingWindowCache(StaticCache):
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
Example:
@@ -1152,13 +1169,22 @@ class SlidingWindowCache(StaticCache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
@@ -1168,7 +1194,12 @@ class SlidingWindowCache(StaticCache):
)
max_cache_len = min(config.sliding_window, max_cache_len)
super().__init__(
config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
config=config,
batch_size=batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
max_batch_size=max_batch_size,
)
def update(
@@ -1407,13 +1438,14 @@ class HybridCache(Cache):
Parameters:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`, *optional*, defaults to `"cpu"`):
device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
Example:
@@ -1429,14 +1461,28 @@ class HybridCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: Union[torch.device, str] = "cpu",
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1444,13 +1490,13 @@ class HybridCache(Cache):
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
self.batch_size = batch_size or max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
@@ -1459,9 +1505,9 @@ class HybridCache(Cache):
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
max_batch_size,
self.batch_size,
self.num_key_value_heads,
min(config.sliding_window, max_cache_len),
self.head_dim,
@@ -1564,11 +1610,12 @@ class MambaCache:
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
dtype (*optional*, defaults to `torch.float16`):
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device`, *optional*):
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
Attributes:
@@ -1596,29 +1643,35 @@ class MambaCache:
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv = outputs.past_key_values
```
"""
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
batch_size: int = None,
dtype: torch.dtype = torch.float16,
device: Optional[str] = None,
**kwargs,
device: Optional[Union[torch.device, str]] = None,
max_batch_size: Optional[int] = None,
):
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
self.dtype = dtype
self.max_batch_size = max_batch_size
self.batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
@@ -1626,7 +1679,7 @@ class MambaCache:
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,

View File

@@ -1426,7 +1426,7 @@ class GenerationMixin:
return model_kwargs
def _get_cache(
self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:
"""
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
@@ -1448,7 +1448,7 @@ class GenerationMixin:
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != max_batch_size
or cache_to_check.batch_size != batch_size
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1473,7 +1473,7 @@ class GenerationMixin:
cache_kwargs = {
"config": self.config,
"max_batch_size": max_batch_size,
"batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
@@ -1812,7 +1812,7 @@ class GenerationMixin:
)
model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation,
max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
max_cache_len=generation_config.max_length,
device=device,
model_kwargs=model_kwargs,

View File

@@ -818,7 +818,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,

View File

@@ -1040,7 +1040,7 @@ class Mask4DTestHard(unittest.TestCase):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,
@@ -1088,7 +1088,7 @@ class Mask4DTestHard(unittest.TestCase):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,

View File

@@ -47,12 +47,12 @@ if is_torch_available():
end_of_text_token = 32000
class Phi3MiniWithStaticCache(torch.nn.Module):
def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
super().__init__()
self.model = model
self.cache = StaticCache(
config=model.config,
max_batch_size=max_batch_size,
batch_size=batch_size,
max_cache_len=max_seq_len,
device=self.model.device,
dtype=self.model.dtype,

View File

@@ -216,7 +216,7 @@ class AqlmTest(unittest.TestCase):
# Setup static KV cache for generation
past_key_values = StaticCache(
config=self.quantized_model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=seq_length + self.max_new_tokens + 1,
device=torch_device,
dtype=self.quantized_model.config._pre_quantization_dtype,

View File

@@ -145,7 +145,7 @@ class CacheTest(unittest.TestCase):
return random_keys, random_values
mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@@ -153,7 +153,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@@ -161,7 +161,7 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
)
@@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase):
device = "cpu"
dtype = torch.float32
max_batch_size = 1
batch_size = 1
config = AutoConfig.from_pretrained(
"google/gemma-2b",
@@ -203,7 +203,7 @@ class CacheTest(unittest.TestCase):
self.config = config
self.model = model
self.static_cache = StaticCache(
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
)
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):

View File

@@ -75,6 +75,11 @@ OBJECTS_TO_IGNORE = [
"TFSequenceSummary",
"TFBertTokenizer",
"TFGPT2Tokenizer",
# Going through an argument deprecation cycle, remove after v4.46
"HybridCache",
"MambaCache",
"SlidingWindowCache",
"StaticCache",
# Missing arguments in the docstring
"ASTFeatureExtractor",
"AlbertModel",