Cache: Static cache as a standalone object (#30476)
This commit is contained in:
@@ -196,9 +196,14 @@ class AqlmTest(unittest.TestCase):
|
||||
"""
|
||||
|
||||
# Sample tokens greedily
|
||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
||||
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
|
||||
logits = model(
|
||||
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
|
||||
cur_token,
|
||||
position_ids=input_pos,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)[0]
|
||||
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
|
||||
|
||||
@@ -209,7 +214,13 @@ class AqlmTest(unittest.TestCase):
|
||||
seq_length = input_ids.shape[1]
|
||||
|
||||
# Setup static KV cache for generation
|
||||
self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1)
|
||||
past_key_values = StaticCache(
|
||||
config=self.quantized_model.config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=seq_length + self.max_new_tokens + 1,
|
||||
device=torch_device,
|
||||
dtype=self.quantized_model.config._pre_quantization_dtype,
|
||||
)
|
||||
|
||||
# Allocate token ids to be generated and copy prefix ids
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
@@ -217,7 +228,13 @@ class AqlmTest(unittest.TestCase):
|
||||
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
|
||||
|
||||
# Do a forward pass to fill the prefix cache and compile the kernels if necessary
|
||||
logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
|
||||
logits = self.quantized_model(
|
||||
input_ids,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)[0]
|
||||
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
|
||||
generated_ids[:, [seq_length]] = next_token
|
||||
|
||||
@@ -229,7 +246,9 @@ class AqlmTest(unittest.TestCase):
|
||||
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
||||
for _ in range(1, self.max_new_tokens):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||
next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
|
||||
next_token = decode_one_tokens(
|
||||
self.quantized_model, next_token.clone(), None, cache_position, past_key_values
|
||||
)
|
||||
generated_ids.index_copy_(1, cache_position, next_token)
|
||||
cache_position += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user