Cache: Static cache as a standalone object (#30476)

This commit is contained in:
Joao Gante
2024-04-30 16:37:19 +01:00
committed by GitHub
parent 0ae789e043
commit 75bbfd5b22
20 changed files with 377 additions and 424 deletions

View File

@@ -196,9 +196,14 @@ class AqlmTest(unittest.TestCase):
"""
# Sample tokens greedily
def decode_one_tokens(model, cur_token, input_pos, cache_position):
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
logits = model(
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
cur_token,
position_ids=input_pos,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0]
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
@@ -209,7 +214,13 @@ class AqlmTest(unittest.TestCase):
seq_length = input_ids.shape[1]
# Setup static KV cache for generation
self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1)
past_key_values = StaticCache(
config=self.quantized_model.config,
max_batch_size=1,
max_cache_len=seq_length + self.max_new_tokens + 1,
device=torch_device,
dtype=self.quantized_model.config._pre_quantization_dtype,
)
# Allocate token ids to be generated and copy prefix ids
cache_position = torch.arange(seq_length, device=torch_device)
@@ -217,7 +228,13 @@ class AqlmTest(unittest.TestCase):
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
# Do a forward pass to fill the prefix cache and compile the kernels if necessary
logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
logits = self.quantized_model(
input_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False,
use_cache=True,
)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token
@@ -229,7 +246,9 @@ class AqlmTest(unittest.TestCase):
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, self.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
next_token = decode_one_tokens(
self.quantized_model, next_token.clone(), None, cache_position, past_key_values
)
generated_ids.index_copy_(1, cache_position, next_token)
cache_position += 1