From cf32ee1753c9747b877113a309c2aa989f6d006c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 16 Aug 2024 11:48:45 +0100 Subject: [PATCH] Cache: use `batch_size` instead of `max_batch_size` (#32657) * more precise name * better docstrings * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/llm_optims.md | 4 +- src/transformers/cache_utils.py | 127 +++++++++++++----- src/transformers/generation/utils.py | 8 +- .../models/gemma2/modeling_gemma2.py | 2 +- tests/models/llama/test_modeling_llama.py | 4 +- tests/models/phi3/test_modeling_phi3.py | 4 +- .../aqlm_integration/test_aqlm.py | 2 +- tests/utils/test_cache_utils.py | 10 +- utils/check_docstrings.py | 5 + 9 files changed, 112 insertions(+), 54 deletions(-) diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 8e7e9c54d4..881cd6cd75 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16 past_key_values = StaticCache( config=model.config, - max_batch_size=1, + batch_size=1, # If you plan to reuse the cache, make sure the cache length is large enough for all cases max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2), device=model.device, @@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc batch_size, seq_length = inputs["input_ids"].shape with torch.no_grad(): past_key_values = StaticCache( - config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype + config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype ) cache_position = torch.arange(seq_length, device=torch_device) generated_ids = torch.zeros( diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index bb9c4565d1..56eb0c4080 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -977,13 +977,14 @@ class StaticCache(Cache): Parameters: config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search max_cache_len (`int`): The maximum sequence length with which the model will be used. - device (`torch.device`): + device (`torch.device` or `str`): The device on which the cache should be initialized. Should be the same as the layer. - dtype (*optional*, defaults to `torch.float32`): + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. Example: @@ -999,22 +1000,37 @@ class StaticCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation ``` """ - def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + max_cache_len: int = None, + device: torch.device = None, + dtype: torch.dtype = torch.float32, + max_batch_size: Optional[int] = None, + ) -> None: super().__init__() - self.max_batch_size = max_batch_size + if max_batch_size is not None: + logger.warning_once( + f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.46. Use the more precisely named 'batch_size' argument instead." + ) + + self.batch_size = batch_size or max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads self.head_dim = ( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) - self.dtype = dtype if dtype is not None else torch.float32 + self.dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None @@ -1024,7 +1040,7 @@ class StaticCache(Cache): self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] # Note: There will be significant perf decrease if switching to use 5D tensors instead. - cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) + cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) for idx in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) @@ -1130,13 +1146,14 @@ class SlidingWindowCache(StaticCache): Parameters: config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. max_cache_len (`int`): The maximum sequence length with which the model will be used. - device (`torch.device`): + device (`torch.device` or `str`): The device on which the cache should be initialized. Should be the same as the layer. - dtype (*optional*, defaults to `torch.float32`): + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. Example: @@ -1152,13 +1169,22 @@ class SlidingWindowCache(StaticCache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation ``` """ - def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + max_cache_len: int = None, + device: torch.device = None, + dtype: torch.dtype = torch.float32, + max_batch_size: Optional[int] = None, + ) -> None: super().__init__() if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( @@ -1168,7 +1194,12 @@ class SlidingWindowCache(StaticCache): ) max_cache_len = min(config.sliding_window, max_cache_len) super().__init__( - config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype + config=config, + batch_size=batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + max_batch_size=max_batch_size, ) def update( @@ -1407,13 +1438,14 @@ class HybridCache(Cache): Parameters: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. max_cache_len (`int`): The maximum sequence length with which the model will be used. - device (`torch.device`, *optional*, defaults to `"cpu"`): + device (`torch.device` or `str`, *optional*, defaults to `"cpu"`): The device on which the cache should be initialized. Should be the same as the layer. - dtype (*optional*, defaults to `torch.float32`): + dtype (torch.dtype, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. Example: @@ -1429,14 +1461,28 @@ class HybridCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation ``` """ - def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None: + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + max_cache_len: int = None, + device: Union[torch.device, str] = "cpu", + dtype: torch.dtype = torch.float32, + max_batch_size: Optional[int] = None, + ) -> None: super().__init__() + if max_batch_size is not None: + logger.warning_once( + f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.46. Use the more precisely named 'batch_size' argument instead." + ) if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " @@ -1444,13 +1490,13 @@ class HybridCache(Cache): "config and it's not set to None." ) self.max_cache_len = max_cache_len - self.max_batch_size = max_batch_size + self.batch_size = batch_size or max_batch_size # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads self.head_dim = ( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) - self.dtype = dtype if dtype is not None else torch.float32 + self.dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) @@ -1459,9 +1505,9 @@ class HybridCache(Cache): ) self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] - global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) + global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim) sliding_cache_shape = ( - max_batch_size, + self.batch_size, self.num_key_value_heads, min(config.sliding_window, max_cache_len), self.head_dim, @@ -1564,11 +1610,12 @@ class MambaCache: Arguments: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - dtype (*optional*, defaults to `torch.float16`): + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): The default `dtype` to use when initializing the layer. - device (`torch.device`, *optional*): + device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. Should be the same as the layer. Attributes: @@ -1596,29 +1643,35 @@ class MambaCache: >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> past_kv = outputs.past_key_values ``` """ + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. def __init__( self, config: PretrainedConfig, - max_batch_size: int, + batch_size: int = None, dtype: torch.dtype = torch.float16, - device: Optional[str] = None, - **kwargs, + device: Optional[Union[torch.device, str]] = None, + max_batch_size: Optional[int] = None, ): + if max_batch_size is not None: + logger.warning_once( + f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.46. Use the more precisely named 'batch_size' argument instead." + ) self.dtype = dtype - self.max_batch_size = max_batch_size + self.batch_size = batch_size or max_batch_size self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel self.conv_states: torch.Tensor = torch.zeros( config.num_hidden_layers, - self.max_batch_size, + self.batch_size, self.intermediate_size, self.conv_kernel_size, device=device, @@ -1626,7 +1679,7 @@ class MambaCache: ) self.ssm_states: torch.Tensor = torch.zeros( config.num_hidden_layers, - self.max_batch_size, + self.batch_size, self.intermediate_size, self.ssm_state_size, device=device, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 24c9e3bb18..62a46330ac 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1426,7 +1426,7 @@ class GenerationMixin: return model_kwargs def _get_cache( - self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs + self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs ) -> Cache: """ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a @@ -1448,7 +1448,7 @@ class GenerationMixin: need_new_cache = ( not hasattr(self, "_cache") or (not isinstance(cache_to_check, cache_cls)) - or cache_to_check.max_batch_size != max_batch_size + or cache_to_check.batch_size != batch_size ) if cache_implementation != "mamba": need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len @@ -1473,7 +1473,7 @@ class GenerationMixin: cache_kwargs = { "config": self.config, - "max_batch_size": max_batch_size, + "batch_size": batch_size, "max_cache_len": max_cache_len, "device": device, "dtype": cache_dtype, @@ -1812,7 +1812,7 @@ class GenerationMixin: ) model_kwargs[cache_name] = self._get_cache( cache_implementation=generation_config.cache_implementation, - max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, + batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, max_cache_len=generation_config.max_length, device=device, model_kwargs=model_kwargs, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index ee5af616ec..5ae357a527 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -818,7 +818,7 @@ class Gemma2Model(Gemma2PreTrainedModel): batch_size, seq_len, _ = inputs_embeds.shape past_key_values = HybridCache( self.config, - max_batch_size=batch_size, + batch_size=batch_size, max_cache_len=seq_len, device=self.device, dtype=inputs_embeds.dtype, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index a32fa3437e..ed2aa922fe 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -1040,7 +1040,7 @@ class Mask4DTestHard(unittest.TestCase): max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] past_key_values = StaticCache( config=self.model.config, - max_batch_size=1, + batch_size=1, max_cache_len=max_cache_len, device=torch_device, dtype=self.model.dtype, @@ -1088,7 +1088,7 @@ class Mask4DTestHard(unittest.TestCase): max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1] past_key_values = StaticCache( config=self.model.config, - max_batch_size=1, + batch_size=1, max_cache_len=max_cache_len, device=torch_device, dtype=self.model.dtype, diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index eb113a4df6..a3f001aba4 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -47,12 +47,12 @@ if is_torch_available(): end_of_text_token = 32000 class Phi3MiniWithStaticCache(torch.nn.Module): - def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int): + def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int): super().__init__() self.model = model self.cache = StaticCache( config=model.config, - max_batch_size=max_batch_size, + batch_size=batch_size, max_cache_len=max_seq_len, device=self.model.device, dtype=self.model.dtype, diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py index 3b0dd99adc..b79eae54c0 100644 --- a/tests/quantization/aqlm_integration/test_aqlm.py +++ b/tests/quantization/aqlm_integration/test_aqlm.py @@ -216,7 +216,7 @@ class AqlmTest(unittest.TestCase): # Setup static KV cache for generation past_key_values = StaticCache( config=self.quantized_model.config, - max_batch_size=1, + batch_size=1, max_cache_len=seq_length + self.max_new_tokens + 1, device=torch_device, dtype=self.quantized_model.config._pre_quantization_dtype, diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 5b61999b56..4a9acf4a27 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -145,7 +145,7 @@ class CacheTest(unittest.TestCase): return random_keys, random_values mha_config = LlamaConfig(num_attention_heads=32) - mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mha_static_cache.update( *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -153,7 +153,7 @@ class CacheTest(unittest.TestCase): self.assertTrue(cached_values.shape == (1, 32, 10, 128)) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) - gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = gqa_static_cache.update( *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -161,7 +161,7 @@ class CacheTest(unittest.TestCase): self.assertTrue(cached_values.shape == (1, 4, 10, 128)) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) - mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) + mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mqa_static_cache.update( *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase): device = "cpu" dtype = torch.float32 - max_batch_size = 1 + batch_size = 1 config = AutoConfig.from_pretrained( "google/gemma-2b", @@ -203,7 +203,7 @@ class CacheTest(unittest.TestCase): self.config = config self.model = model self.static_cache = StaticCache( - config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device + config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device ) def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor): diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 928bd332d2..aff2dd6a43 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -75,6 +75,11 @@ OBJECTS_TO_IGNORE = [ "TFSequenceSummary", "TFBertTokenizer", "TFGPT2Tokenizer", + # Going through an argument deprecation cycle, remove after v4.46 + "HybridCache", + "MambaCache", + "SlidingWindowCache", + "StaticCache", # Missing arguments in the docstring "ASTFeatureExtractor", "AlbertModel",