[Cache] rename dtype attribute 🚨 🚨 (#37044)
* yoink * same pattern in all cache
This commit is contained in:
@@ -1199,7 +1199,7 @@ class StaticCache(Cache):
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
|
||||
self.dtype = dtype
|
||||
self._dtype = dtype
|
||||
self.num_key_value_heads = (
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
@@ -1216,8 +1216,8 @@ class StaticCache(Cache):
|
||||
layer_device = layer_device_map[idx]
|
||||
else:
|
||||
layer_device = device
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
||||
# preventing compiled graph breaks when updating the cache.
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
@@ -1680,7 +1680,7 @@ class HybridCache(Cache):
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
|
||||
self.dtype = dtype
|
||||
self._dtype = dtype
|
||||
self.num_key_value_heads = (
|
||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||
)
|
||||
@@ -1707,8 +1707,8 @@ class HybridCache(Cache):
|
||||
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
||||
# breaks when updating the cache.
|
||||
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||
self.key_cache.append(new_layer_key_cache)
|
||||
@@ -1853,8 +1853,8 @@ class MambaCache:
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: Union[torch.device, str, None] = None,
|
||||
):
|
||||
self.dtype = dtype
|
||||
self.max_batch_size = max_batch_size
|
||||
self._dtype = dtype
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
@@ -1868,14 +1868,14 @@ class MambaCache:
|
||||
self.intermediate_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
ssm_state: torch.Tensor = torch.zeros(
|
||||
self.max_batch_size,
|
||||
self.intermediate_size,
|
||||
self.ssm_state_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
torch._dynamo.mark_static_address(conv_state)
|
||||
@@ -1972,7 +1972,7 @@ class OffloadedStaticCache(StaticCache):
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
|
||||
self.offload_device = torch.device(offload_device)
|
||||
self.dtype = dtype if dtype is not None else torch.float32
|
||||
self._dtype = dtype if dtype is not None else torch.float32
|
||||
|
||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
@@ -2144,8 +2144,8 @@ class OffloadedStaticCache(StaticCache):
|
||||
|
||||
is_cpu_device = device == torch.device("cpu")
|
||||
|
||||
key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
|
||||
value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
|
||||
key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
|
||||
value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
|
||||
|
||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
||||
# preventing compiled graph breaks when updating the cache.
|
||||
|
||||
Reference in New Issue
Block a user