[Cache] rename dtype attribute 🚨 🚨 (#37044)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled

* yoink

* same pattern in all cache
This commit is contained in:
Joao Gante
2025-03-28 18:08:02 +00:00
committed by GitHub
parent 9fd9476005
commit bab605dd04

View File

@@ -1199,7 +1199,7 @@ class StaticCache(Cache):
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
) )
self.dtype = dtype self._dtype = dtype
self.num_key_value_heads = ( self.num_key_value_heads = (
config.num_attention_heads config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None if getattr(config, "num_key_value_heads", None) is None
@@ -1216,8 +1216,8 @@ class StaticCache(Cache):
layer_device = layer_device_map[idx] layer_device = layer_device_map[idx]
else: else:
layer_device = device layer_device = device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache. # preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1680,7 +1680,7 @@ class HybridCache(Cache):
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
) )
self.dtype = dtype self._dtype = dtype
self.num_key_value_heads = ( self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
) )
@@ -1707,8 +1707,8 @@ class HybridCache(Cache):
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. # breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache) torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache) self.key_cache.append(new_layer_key_cache)
@@ -1853,8 +1853,8 @@ class MambaCache:
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
device: Union[torch.device, str, None] = None, device: Union[torch.device, str, None] = None,
): ):
self.dtype = dtype
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self._dtype = dtype
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel self.conv_kernel_size = config.conv_kernel
@@ -1868,14 +1868,14 @@ class MambaCache:
self.intermediate_size, self.intermediate_size,
self.conv_kernel_size, self.conv_kernel_size,
device=device, device=device,
dtype=dtype, dtype=self._dtype,
) )
ssm_state: torch.Tensor = torch.zeros( ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size, self.max_batch_size,
self.intermediate_size, self.intermediate_size,
self.ssm_state_size, self.ssm_state_size,
device=device, device=device,
dtype=dtype, dtype=self._dtype,
) )
torch._dynamo.mark_static_address(conv_state) torch._dynamo.mark_static_address(conv_state)
@@ -1972,7 +1972,7 @@ class OffloadedStaticCache(StaticCache):
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
self.offload_device = torch.device(offload_device) self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32 self._dtype = dtype if dtype is not None else torch.float32
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -2144,8 +2144,8 @@ class OffloadedStaticCache(StaticCache):
is_cpu_device = device == torch.device("cpu") is_cpu_device = device == torch.device("cpu")
key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device) value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache. # preventing compiled graph breaks when updating the cache.