Make cache traceable (#35873)

simply make cache traceable
This commit is contained in:
Ilyas Moutawwakil
2025-02-20 09:59:25 +01:00
committed by GitHub
parent 31bb662db1
commit 5e2183f344
3 changed files with 21 additions and 30 deletions

View File

@@ -9,12 +9,7 @@ import torch
from packaging import version from packaging import version
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import ( from .utils import is_hqq_available, is_optimum_quanto_available, logging
is_hqq_available,
is_optimum_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .utils.deprecation import deprecate_kwarg from .utils.deprecation import deprecate_kwarg
@@ -24,7 +19,7 @@ if is_hqq_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class Cache(torch.nn.Module): class Cache:
""" """
Base, abstract class for all caches. The actual data structure is specific to each subclass. Base, abstract class for all caches. The actual data structure is specific to each subclass.
""" """
@@ -1140,16 +1135,8 @@ class StaticCache(Cache):
layer_device = self.device layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes: # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # preventing compiled graph breaks when updating the cache.
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache) torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache) self.key_cache.append(new_layer_key_cache)

View File

@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available
if is_torch_available(): if is_torch_available():
from transformers import ( from transformers import PreTrainedModel, StaticCache
PreTrainedModel,
StaticCache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
@@ -72,9 +69,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
config=self.model.config, config=self.model.config,
batch_size=self.model.generation_config.cache_config.batch_size, batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len, max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.dtype,
device=self.model.generation_config.cache_config.device, device=self.model.generation_config.cache_config.device,
dtype=self.model.dtype,
) )
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal: if self.is_causal:
causal_mask = torch.tril( causal_mask = torch.tril(
@@ -109,12 +110,15 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
""" """
_, seqlen = input_ids.shape _, seqlen = input_ids.shape
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache
outs = self.model( outs = self.model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attn_mask, attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0), position_ids=position_ids,
cache_position=cache_position, cache_position=cache_position,
past_key_values=self.static_cache, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
return outs.logits return outs.logits
@@ -143,7 +147,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
prompt_token_len = prompt_token_ids.shape[-1] prompt_token_len = prompt_token_ids.shape[-1]
max_generation_length = prompt_token_len + max_new_tokens max_generation_length = prompt_token_len + max_new_tokens
for buffer_name, buffer in exported_program.named_buffers(): for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"): if buffer_name.startswith("key_cache"):
max_cache_len = buffer.shape[2] max_cache_len = buffer.shape[2]
max_generation_length = min(max_generation_length, max_cache_len) max_generation_length = min(max_generation_length, max_cache_len)
break break

View File

@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
# Check if the exported model is configured with the `StaticCache` correctly # Check if the exported model is configured with the `StaticCache` correctly
n_static_key_caches = n_static_value_caches = 0 n_static_key_caches = n_static_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers(): for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"): if buffer_name.startswith("key_cache"):
self.assertTrue(buffer.shape[0] == batch_size) self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len) self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_key_caches = n_static_key_caches + 1 n_static_key_caches = n_static_key_caches + 1
if buffer_name.startswith("static_cache.value_cache"): if buffer_name.startswith("value_cache"):
self.assertTrue(buffer.shape[0] == batch_size) self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len) self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_value_caches = n_static_value_caches + 1 n_static_value_caches = n_static_value_caches + 1
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week", "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the' 'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
] # fmt: skip ] # fmt: skip
self.assertTrue(responses == EXPECTED_DECODED_TEXT) self.assertEqual(responses, EXPECTED_DECODED_TEXT)