[Cache] rename dtype attribute 🚨 🚨 (#37044)
* yoink * same pattern in all cache
This commit is contained in:
@@ -1199,7 +1199,7 @@ class StaticCache(Cache):
|
|||||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dtype = dtype
|
self._dtype = dtype
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_attention_heads
|
config.num_attention_heads
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
@@ -1216,8 +1216,8 @@ class StaticCache(Cache):
|
|||||||
layer_device = layer_device_map[idx]
|
layer_device = layer_device_map[idx]
|
||||||
else:
|
else:
|
||||||
layer_device = device
|
layer_device = device
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
||||||
# preventing compiled graph breaks when updating the cache.
|
# preventing compiled graph breaks when updating the cache.
|
||||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||||
@@ -1680,7 +1680,7 @@ class HybridCache(Cache):
|
|||||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dtype = dtype
|
self._dtype = dtype
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
@@ -1707,8 +1707,8 @@ class HybridCache(Cache):
|
|||||||
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
||||||
# breaks when updating the cache.
|
# breaks when updating the cache.
|
||||||
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||||
self.key_cache.append(new_layer_key_cache)
|
self.key_cache.append(new_layer_key_cache)
|
||||||
@@ -1853,8 +1853,8 @@ class MambaCache:
|
|||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.float16,
|
||||||
device: Union[torch.device, str, None] = None,
|
device: Union[torch.device, str, None] = None,
|
||||||
):
|
):
|
||||||
self.dtype = dtype
|
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
|
self._dtype = dtype
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.ssm_state_size = config.state_size
|
self.ssm_state_size = config.state_size
|
||||||
self.conv_kernel_size = config.conv_kernel
|
self.conv_kernel_size = config.conv_kernel
|
||||||
@@ -1868,14 +1868,14 @@ class MambaCache:
|
|||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.conv_kernel_size,
|
self.conv_kernel_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=self._dtype,
|
||||||
)
|
)
|
||||||
ssm_state: torch.Tensor = torch.zeros(
|
ssm_state: torch.Tensor = torch.zeros(
|
||||||
self.max_batch_size,
|
self.max_batch_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
self.ssm_state_size,
|
self.ssm_state_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=self._dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch._dynamo.mark_static_address(conv_state)
|
torch._dynamo.mark_static_address(conv_state)
|
||||||
@@ -1972,7 +1972,7 @@ class OffloadedStaticCache(StaticCache):
|
|||||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||||
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
|
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
|
||||||
self.offload_device = torch.device(offload_device)
|
self.offload_device = torch.device(offload_device)
|
||||||
self.dtype = dtype if dtype is not None else torch.float32
|
self._dtype = dtype if dtype is not None else torch.float32
|
||||||
|
|
||||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
@@ -2144,8 +2144,8 @@ class OffloadedStaticCache(StaticCache):
|
|||||||
|
|
||||||
is_cpu_device = device == torch.device("cpu")
|
is_cpu_device = device == torch.device("cpu")
|
||||||
|
|
||||||
key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
|
key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
|
||||||
value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
|
value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
|
||||||
|
|
||||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
||||||
# preventing compiled graph breaks when updating the cache.
|
# preventing compiled graph breaks when updating the cache.
|
||||||
|
|||||||
Reference in New Issue
Block a user