From c4161238bd8f67a2d80715de7d4ce45541955693 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 13 Mar 2025 10:13:29 +0000 Subject: [PATCH] [Cache] Don't initialize the cache on `meta` device (#36543) --- src/transformers/cache_utils.py | 158 ++++++------------ .../generation/configuration_utils.py | 2 +- src/transformers/generation/utils.py | 38 ++++- .../models/cohere2/modeling_cohere2.py | 2 + .../models/cohere2/modular_cohere2.py | 2 + .../models/gemma2/modeling_gemma2.py | 2 + .../models/gemma2/modular_gemma2.py | 2 + tests/generation/test_utils.py | 39 ----- tests/utils/test_cache_utils.py | 40 +++++ 9 files changed, 138 insertions(+), 147 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 94a68bf0df..11c25b2827 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -10,7 +10,6 @@ from packaging import version from .configuration_utils import PretrainedConfig from .utils import is_hqq_available, is_optimum_quanto_available, logging -from .utils.deprecation import deprecate_kwarg if is_hqq_available(): @@ -1064,18 +1063,19 @@ class StaticCache(Cache): The configuration file defining the shape-related attributes required to initialize the static cache. batch_size (`int`): The batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search + smaller batch size is used. If you are manually setting the batch size, make sure to take into account the + number of beams if you are running beam search max_cache_len (`int`): The maximum sequence length with which the model will be used. device (`torch.device` or `str`): - The device on which the cache should be initialized. Should be the same as the layer. - The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` - device by default, and then moved to input device when updating. + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): - Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus. - You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is splitted between differents gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. Example: @@ -1101,7 +1101,6 @@ class StaticCache(Cache): is_compileable = True # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. - @deprecate_kwarg("layer_device_map", version="4.52.0") def __init__( self, config: PretrainedConfig, @@ -1128,7 +1127,6 @@ class StaticCache(Cache): ) self.dtype = dtype - self.device = torch.device(device) if device is not None else torch.device("meta") self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None @@ -1139,11 +1137,12 @@ class StaticCache(Cache): self.value_cache: List[torch.Tensor] = [] # Note: There will be significant perf decrease if switching to use 5D tensors instead. cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) + device = torch.device(device) if device is not None else None for idx in range(config.num_hidden_layers): if layer_device_map is not None: layer_device = layer_device_map[idx] else: - layer_device = self.device + layer_device = device new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, @@ -1178,12 +1177,7 @@ class StaticCache(Cache): Return: A tuple containing the updated key and value states. """ - cache_position = cache_kwargs.get("cache_position") - if self.key_cache[layer_idx].device.type == "meta": - self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device) - self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device) - k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] key_states = key_states.to(k_out.dtype) @@ -1211,8 +1205,6 @@ class StaticCache(Cache): # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. # TODO: deprecate this function in favor of `cache_position` - if self.key_cache[layer_idx].device.type == "meta": - return 0 return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def get_max_cache_shape(self) -> Optional[int]: @@ -1221,10 +1213,9 @@ class StaticCache(Cache): def reset(self): """Resets the cache values while preserving the objects""" for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].device.type != "meta": - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() @property def batch_size(self): @@ -1261,14 +1252,14 @@ class SlidingWindowCache(StaticCache): max_cache_len (`int`): The maximum sequence length with which the model will be used. device (`torch.device` or `str`): - The device on which the cache should be initialized. Should be the same as the layer. - The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` - device by default, and then moved to input device when updating. + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): - Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus. - You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is splitted between differents gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. Example: @@ -1329,11 +1320,6 @@ class SlidingWindowCache(StaticCache): cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor]: cache_position = cache_kwargs.get("cache_position") - - if self.key_cache[layer_idx].device.type == "meta": - self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device) - self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device) - k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] key_states = key_states.to(k_out.dtype) @@ -1380,10 +1366,9 @@ class SlidingWindowCache(StaticCache): def reset(self): for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].device.type != "meta": - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() class EncoderDecoderCache(Cache): @@ -1573,14 +1558,14 @@ class HybridCache(Cache): max_cache_len (`int`): The maximum sequence length with which the model will be used. device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` - device by default, and then moved to input device when updating. + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. dtype (torch.dtype, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): - Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus. - You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is splitted between differents gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. Example: @@ -1607,7 +1592,6 @@ class HybridCache(Cache): # is_compileable = True # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. - @deprecate_kwarg("layer_device_map", version="4.52.0") def __init__( self, config: PretrainedConfig, @@ -1642,7 +1626,6 @@ class HybridCache(Cache): config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) - self.device = torch.device(device) if device is not None else torch.device("meta") layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC self.is_sliding = torch.tensor( [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool @@ -1656,11 +1639,12 @@ class HybridCache(Cache): min(config.sliding_window, max_cache_len), self.head_dim, ) + device = torch.device(device) if device is not None else None for i in range(config.num_hidden_layers): if layer_device_map is not None: layer_device = layer_device_map[i] else: - layer_device = self.device + layer_device = device # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape @@ -1717,9 +1701,12 @@ class HybridCache(Cache): cache_position = cache_kwargs.get("cache_position") sliding_window = cache_kwargs.get("sliding_window") - if self.key_cache[layer_idx].device.type == "meta": - self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device) - self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device) + # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used + # when the cache is initialized in the forward pass (e.g. Gemma2) + if self.key_cache[layer_idx].device != key_states.device: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + if self.value_cache[layer_idx].device != value_states.device: + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -1753,18 +1740,14 @@ class HybridCache(Cache): "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " "Using the `layer_idx` argument is not supported." ) - - if self.key_cache[layer_idx].device.type == "meta": - return 0 return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def reset(self): """Resets the cache values while preserving the objects""" for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].device.type != "meta": - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() @property def batch_size(self): @@ -1789,24 +1772,6 @@ class MambaCache: The default `dtype` to use when initializing the layer. device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. Should be the same as the layer. - The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` - device by default, and then moved to input device when updating. - - Attributes: - dtype: (`torch.dtype`): - The default `dtype` used to initializing the cache. - device (`torch.device`): - The default device on which the cache was initialized. - intermediate_size: (`int`): - Model's intermediate_size taken from config. - ssm_state_size: (`int`): - Model's state_size taken from config. - conv_kernel_size: (`int`): - Model's convolution kernel size taken from config - conv_states: (`torch.Tensor`): - A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states. - ssm_states: (`torch.Tensor`): - A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states Example: @@ -1829,6 +1794,7 @@ class MambaCache: is_compileable = True # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + # TODO (joao): add layer_device_map arg and update code in `generate` accordingly def __init__( self, config: PretrainedConfig, @@ -1847,23 +1813,23 @@ class MambaCache: self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel - self.device = torch.device(device) if device is not None else torch.device("meta") self.conv_states: List[torch.Tensor] = [] self.ssm_states: List[torch.Tensor] = [] + device = torch.device(device) if device is not None else None for _ in range(config.num_hidden_layers): conv_state: torch.Tensor = torch.zeros( self.max_batch_size, self.intermediate_size, self.conv_kernel_size, - device=self.device, + device=device, dtype=dtype, ) ssm_state: torch.Tensor = torch.zeros( self.max_batch_size, self.intermediate_size, self.ssm_state_size, - device=self.device, + device=device, dtype=dtype, ) @@ -1875,11 +1841,10 @@ class MambaCache: def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor ) -> torch.Tensor: - if self.conv_states[layer_idx].device.type == "meta": - self.conv_states[layer_idx] = torch.zeros_like( - self.conv_states[layer_idx], - device=new_conv_state.device, - ) + # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used + # when the cache is initialized in the forward pass (e.g. Mamba) + if self.conv_states[layer_idx].device != new_conv_state.device: + self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) conv_state = self.conv_states[layer_idx] cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) @@ -1896,10 +1861,9 @@ class MambaCache: def reset(self): for layer_idx in range(len(self.conv_states)): - if self.conv_states[layer_idx].device.type != "meta": - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() @property def batch_size(self): @@ -1924,33 +1888,16 @@ class OffloadedStaticCache(StaticCache): max_cache_len (`int`): The maximum sequence length with which the model will be used. device (`Union[str, torch.device]`): - The device on which the cache should be initialized. Should be the same as the - layer device. + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. dtype (`torch.dtype`, *optional*): The default `dtype` to use when initializing the cache. offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): The device to offload to. Defaults to CPU. layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus. - You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. - - Attributes: - key_cache (`List[torch.Tensor]`): - Off-loaded key cache tensors. First one will be on device, where-as the others are - off-loaded. - value_cache (`List[torch.Tensor]`): - Off-loaded value cache tensors. First one will be on device, where-as the others are - off-loaded. - max_batch_size (`int`): - The maximum batch size with which this cache can be used. - max_cache_len (`int`): - The maximum sequence length with which this cache can be used. - device (`torch.device`): - The device on which the cache is used. - offload_device (`torch.device`): - The device used to offload to. - dtype (`torch.dtype`): - The `dtype` used to initializing the cache. + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is splitted between differents gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. Example: @@ -1973,7 +1920,6 @@ class OffloadedStaticCache(StaticCache): is_compileable = True - @deprecate_kwarg("layer_device_map", version="4.52.0") def __init__( self, config: PretrainedConfig, diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 50f7b29f75..6ee48ab3f1 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -483,7 +483,7 @@ class GenerationConfig(PushToHubMixin): self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10) self.target_lookbehind = kwargs.pop("target_lookbehind", 10) - # Performances + # Performance self.compile_config = kwargs.pop("compile_config", CompileConfig()) self.disable_compile = kwargs.pop("disable_compile", False) # Wild card diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5c7fffb117..8d5be7d7a0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1618,6 +1618,40 @@ class GenerationMixin: model_kwargs["cache_position"] = cache_position return model_kwargs + def _get_layer_device_map_for_cache_init(self): + """ + Taken from `dispatch_model` from accelerate. + This is needed here if we don't want to make changes in accelerate in order to save execution_device + For offloaded case, we need to get the execution device, not just the device where it is offloaded + """ + execution_device_map = None + + if hasattr(self, "hf_device_map"): + if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}: + main_device = "cpu" + else: + main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] + execution_device_map = { + name: main_device if device in ["cpu", "disk"] else device + for name, device in self.hf_device_map.items() + } + + num_hidden_layers = self.config.get_text_config().num_hidden_layers + if execution_device_map is None: + return None + elif len(execution_device_map) == 1 and "" in execution_device_map: + return {idx: execution_device_map[""] for idx in range(num_hidden_layers)} + layer_device_map = {} + for layer in execution_device_map: + for idx in range(num_hidden_layers): + if f".{idx}." in f"{layer}.": + layer_device_map[idx] = execution_device_map[layer] + break + for idx in range(num_hidden_layers): + if idx not in layer_device_map: + raise RuntimeError(f"layer {idx} has not been mapped to a device.") + return layer_device_map + def _get_cache( self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs ) -> Cache: @@ -1664,12 +1698,14 @@ class GenerationMixin: # models. May cause trobles with non-text modalities. cache_dtype = self.get_output_embeddings().weight.dtype + layer_device_map = self._get_layer_device_map_for_cache_init() cache_kwargs = { "config": self.config.get_text_config(), "max_batch_size": batch_size, "max_cache_len": max_cache_len, "dtype": cache_dtype, - "device": device if cache_implementation == "offloaded_static" else None, + "device": device, + "layer_device_map": layer_device_map, } self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index afd2125d09..1e64518299 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -597,11 +597,13 @@ class Cohere2Model(Cohere2PreTrainedModel): if use_cache and past_key_values is None and not self.training: batch_size, seq_len, _ = inputs_embeds.shape + # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` past_key_values = HybridCache( self.config, max_batch_size=batch_size, max_cache_len=seq_len, dtype=inputs_embeds.dtype, + device=self.device, ) if cache_position is None: diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 979b5abc26..2a2b1c88c3 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -488,11 +488,13 @@ class Cohere2Model(Gemma2Model): if use_cache and past_key_values is None and not self.training: batch_size, seq_len, _ = inputs_embeds.shape + # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` past_key_values = HybridCache( self.config, max_batch_size=batch_size, max_cache_len=seq_len, dtype=inputs_embeds.dtype, + device=self.device, ) if cache_position is None: diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 9db0b8368b..ba16f52c35 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -599,11 +599,13 @@ class Gemma2Model(Gemma2PreTrainedModel): if use_cache and past_key_values is None and not self.training: batch_size, seq_len, _ = inputs_embeds.shape + # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` past_key_values = HybridCache( self.config, max_batch_size=batch_size, max_cache_len=seq_len, dtype=inputs_embeds.dtype, + device=self.device, ) if cache_position is None: diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 0f32c00287..ab567c61d0 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -437,11 +437,13 @@ class Gemma2Model(GemmaModel): if use_cache and past_key_values is None and not self.training: batch_size, seq_len, _ = inputs_embeds.shape + # NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map` past_key_values = HybridCache( self.config, max_batch_size=batch_size, max_cache_len=seq_len, dtype=inputs_embeds.dtype, + device=self.device, ) if cache_position is None: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6d93c77d86..df58c7fc5c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2304,45 +2304,6 @@ class GenerationTesterMixin: without_all_logits = model.generate(**inputs_dict, **generation_kwargs) self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) - @pytest.mark.generate - @is_flaky - def test_assisted_decoding_with_logits_to_keep(self): - for model_class in self.all_generative_model_classes: - if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `logits_to_keep` argument.") - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support assisted generation") - - config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) - # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config.get_text_config(), "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - config.use_cache = True - config.is_decoder = True - - model = model_class(config).to(torch_device).eval() - assistant_model = model - # All generation methods (except assisted decoding) rely on always extracting the last token logits of the - # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, - # other methods will work as well) - generation_kwargs = { - "max_new_tokens": 10, - "do_sample": False, - "assistant_model": assistant_model, - "return_dict_in_generate": True, - "output_scores": True, - } - logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config) - - # Setting logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate( - **generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0 - ) - # By default, logits_to_keep is automatically set to 1 if not provided (new behavior) - without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs) - - self._check_similar_generate_outputs(with_all_logits, without_all_logits) - @pytest.mark.generate def test_inherits_generation_mixin(self): """ diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index e16d30e549..fc7617e649 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -20,6 +20,7 @@ from parameterized import parameterized from transformers import set_seed from transformers.testing_utils import ( + CaptureStderr, get_gpu_count, is_torch_available, require_gptq, @@ -654,3 +655,42 @@ class CacheIntegrationTest(unittest.TestCase): torch.testing.assert_close( actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx] ) + + @require_torch_gpu + def test_static_cache_no_cuda_graph_skips(self): + """ + Tests generating with static cache and compilation doesn't skip cuda graphs. Regression test for #36543. + + (? We set `fullgraph=True`, which according to torch docs means it should raise an exception. Instead, + messages are being thrown to stderr?) + """ + model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM" + model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_repo) + inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device) + + # on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped. + with CaptureStderr() as cap: + model.generate(**inputs, max_new_tokens=2, cache_implementation="static") + self.assertEqual(cap.err, "") + + @require_torch_multi_gpu + def test_static_cache_multi_gpu(self): + """Regression test for #35164: static cache with multi-gpu""" + + model_id = "google/gemma-2-2b-it" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0} + num_hidden_layers = 26 + for i in range(num_hidden_layers): + device_map[f"model.layers.{i}"] = 0 if i < 13 else 1 + + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype="bfloat16", + device_map=device_map, + ) + inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0) + _ = model(**inputs) + _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")