From bab605dd042906c393f16fef8a357cb8f7fd93d2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 28 Mar 2025 18:08:02 +0000 Subject: [PATCH] =?UTF-8?q?[Cache]=20rename=20dtype=20attribute=20?= =?UTF-8?q?=F0=9F=9A=A8=20=F0=9F=9A=A8=20=20(#37044)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * yoink * same pattern in all cache --- src/transformers/cache_utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index f8cced3bb3..6fa96e5c8e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1199,7 +1199,7 @@ class StaticCache(Cache): config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) - self.dtype = dtype + self._dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None @@ -1216,8 +1216,8 @@ class StaticCache(Cache): layer_device = layer_device_map[idx] else: layer_device = device - new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, # preventing compiled graph breaks when updating the cache. torch._dynamo.mark_static_address(new_layer_key_cache) @@ -1680,7 +1680,7 @@ class HybridCache(Cache): config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) - self.dtype = dtype + self._dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) @@ -1707,8 +1707,8 @@ class HybridCache(Cache): # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) @@ -1853,8 +1853,8 @@ class MambaCache: dtype: torch.dtype = torch.float16, device: Union[torch.device, str, None] = None, ): - self.dtype = dtype self.max_batch_size = max_batch_size + self._dtype = dtype self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel @@ -1868,14 +1868,14 @@ class MambaCache: self.intermediate_size, self.conv_kernel_size, device=device, - dtype=dtype, + dtype=self._dtype, ) ssm_state: torch.Tensor = torch.zeros( self.max_batch_size, self.intermediate_size, self.ssm_state_size, device=device, - dtype=dtype, + dtype=self._dtype, ) torch._dynamo.mark_static_address(conv_state) @@ -1972,7 +1972,7 @@ class OffloadedStaticCache(StaticCache): self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) self.offload_device = torch.device(offload_device) - self.dtype = dtype if dtype is not None else torch.float32 + self._dtype = dtype if dtype is not None else torch.float32 # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads @@ -2144,8 +2144,8 @@ class OffloadedStaticCache(StaticCache): is_cpu_device = device == torch.device("cpu") - key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) - value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) + key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) + value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, # preventing compiled graph breaks when updating the cache.