committed by
GitHub
parent
31bb662db1
commit
5e2183f344
@@ -9,12 +9,7 @@ import torch
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .utils import (
|
from .utils import is_hqq_available, is_optimum_quanto_available, logging
|
||||||
is_hqq_available,
|
|
||||||
is_optimum_quanto_available,
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
from .utils.deprecation import deprecate_kwarg
|
from .utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +19,7 @@ if is_hqq_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Cache(torch.nn.Module):
|
class Cache:
|
||||||
"""
|
"""
|
||||||
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
||||||
"""
|
"""
|
||||||
@@ -1140,18 +1135,10 @@ class StaticCache(Cache):
|
|||||||
layer_device = self.device
|
layer_device = self.device
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
# Notes:
|
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
|
||||||
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
# preventing compiled graph breaks when updating the cache.
|
||||||
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
|
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||||
# it is not needed anyway)
|
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||||
# 2. `torch.export()` requires mutations to be registered as buffers.
|
|
||||||
if not is_torchdynamo_compiling():
|
|
||||||
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
|
|
||||||
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
|
|
||||||
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
|
|
||||||
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
|
|
||||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
|
||||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
|
||||||
self.key_cache.append(new_layer_key_cache)
|
self.key_cache.append(new_layer_key_cache)
|
||||||
self.value_cache.append(new_layer_value_cache)
|
self.value_cache.append(new_layer_value_cache)
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import (
|
from transformers import PreTrainedModel, StaticCache
|
||||||
PreTrainedModel,
|
|
||||||
StaticCache,
|
|
||||||
)
|
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
|
||||||
|
|
||||||
|
|
||||||
@@ -72,9 +69,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
config=self.model.config,
|
config=self.model.config,
|
||||||
batch_size=self.model.generation_config.cache_config.batch_size,
|
batch_size=self.model.generation_config.cache_config.batch_size,
|
||||||
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
|
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
|
||||||
dtype=self.model.dtype,
|
|
||||||
device=self.model.generation_config.cache_config.device,
|
device=self.model.generation_config.cache_config.device,
|
||||||
|
dtype=self.model.dtype,
|
||||||
)
|
)
|
||||||
|
for i in range(len(self.static_cache.key_cache)):
|
||||||
|
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
|
||||||
|
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
|
||||||
|
|
||||||
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
|
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
|
||||||
if self.is_causal:
|
if self.is_causal:
|
||||||
causal_mask = torch.tril(
|
causal_mask = torch.tril(
|
||||||
@@ -109,12 +110,15 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
_, seqlen = input_ids.shape
|
_, seqlen = input_ids.shape
|
||||||
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
|
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
past_key_values = self.static_cache
|
||||||
|
|
||||||
outs = self.model(
|
outs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attn_mask,
|
attention_mask=attn_mask,
|
||||||
position_ids=cache_position.unsqueeze(0),
|
position_ids=position_ids,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
past_key_values=self.static_cache,
|
past_key_values=past_key_values,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
return outs.logits
|
return outs.logits
|
||||||
@@ -143,7 +147,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
|
|||||||
prompt_token_len = prompt_token_ids.shape[-1]
|
prompt_token_len = prompt_token_ids.shape[-1]
|
||||||
max_generation_length = prompt_token_len + max_new_tokens
|
max_generation_length = prompt_token_len + max_new_tokens
|
||||||
for buffer_name, buffer in exported_program.named_buffers():
|
for buffer_name, buffer in exported_program.named_buffers():
|
||||||
if buffer_name.startswith("static_cache.key_cache"):
|
if buffer_name.startswith("key_cache"):
|
||||||
max_cache_len = buffer.shape[2]
|
max_cache_len = buffer.shape[2]
|
||||||
max_generation_length = min(max_generation_length, max_cache_len)
|
max_generation_length = min(max_generation_length, max_cache_len)
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
|
|||||||
# Check if the exported model is configured with the `StaticCache` correctly
|
# Check if the exported model is configured with the `StaticCache` correctly
|
||||||
n_static_key_caches = n_static_value_caches = 0
|
n_static_key_caches = n_static_value_caches = 0
|
||||||
for buffer_name, buffer in exported_program.named_buffers():
|
for buffer_name, buffer in exported_program.named_buffers():
|
||||||
if buffer_name.startswith("static_cache.key_cache"):
|
if buffer_name.startswith("key_cache"):
|
||||||
self.assertTrue(buffer.shape[0] == batch_size)
|
self.assertTrue(buffer.shape[0] == batch_size)
|
||||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||||
n_static_key_caches = n_static_key_caches + 1
|
n_static_key_caches = n_static_key_caches + 1
|
||||||
if buffer_name.startswith("static_cache.value_cache"):
|
if buffer_name.startswith("value_cache"):
|
||||||
self.assertTrue(buffer.shape[0] == batch_size)
|
self.assertTrue(buffer.shape[0] == batch_size)
|
||||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||||
n_static_value_caches = n_static_value_caches + 1
|
n_static_value_caches = n_static_value_caches + 1
|
||||||
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
|
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
|
||||||
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
|
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
self.assertTrue(responses == EXPECTED_DECODED_TEXT)
|
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
|
||||||
|
|||||||
Reference in New Issue
Block a user