Offloaded KV Cache (#31325)
* Initial implementation of OffloadedCache
* Enable usage via cache_implementation
* Address feedback, add tests, remove legacy methods
* Remove flash-attn, discover synchronization bugs, fix bugs
* Prevent usage in CPU-only mode
* Add a section about offloaded KV cache to the docs
* Fix typos in docs
* Clarifications and better explanation of streams
This commit is contained in:
committed by
GitHub
parent
b4727a1216
commit
ca59d6f77c
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DynamicCache,
|
||||
GenerationConfig,
|
||||
GPT2LMHeadModel,
|
||||
LlamaConfig,
|
||||
SinkCache,
|
||||
@@ -513,3 +514,54 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
@unittest.skip(reason="TODO @gante static cache's does not support beam search yet")
def test_static_cache_beam_search(self):
    """Placeholder for a static-cache beam search test; skipped unconditionally (see skip reason)."""
|
||||
|
||||
@require_torch_gpu
def test_offloaded_cache_equivalent_to_dynamic_cache(self):
    """Tests that OffloadedCache produces the same result as the default DynamicCache.

    The same prompt is generated twice with identical beam-search settings: once with the
    default generation config and once with `cache_implementation="offloaded"`. The two
    sets of generated token ids must match exactly.
    """
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    device = model.device
    input_text = "Fun fact:"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    # Settings shared by both runs, so the cache implementation is the only difference.
    common = {
        "num_beams": 4,
        "num_beam_groups": 2,
        "num_return_sequences": 4,
        "diversity_penalty": 1.0,
        "max_new_tokens": 20,
        "early_stopping": True,
    }
    original = GenerationConfig(**common)
    offloaded = GenerationConfig(cache_implementation="offloaded", **common)
    original_outputs = model.generate(generation_config=original, **inputs)
    offloaded_outputs = model.generate(generation_config=offloaded, **inputs)
    # Compare the outputs as a whole rather than zipping rows: zip() stops at the shorter
    # input, so a differing number of returned sequences would go unnoticed. A unittest
    # assertion also survives `python -O` and reports the differing token ids on failure.
    # NOTE(review): assumes generate() returns a tensor of token ids (no
    # return_dict_in_generate), as the original element-wise comparison did.
    self.assertListEqual(original_outputs.tolist(), offloaded_outputs.tolist())
|
||||
|
||||
@require_torch_gpu
def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
    """Tests that OffloadedCache uses less memory than the default DynamicCache.

    Peak allocated CUDA memory on the model's device is measured around two otherwise
    identical `generate` calls: one with the default generation config and one with
    `cache_implementation="offloaded"`. The offloaded run must have the strictly lower peak.
    """
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    device = model.device
    input_text = "Fun fact:"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    # Settings shared by both runs, so the cache implementation is the only difference.
    common = {
        "num_beams": 4,
        "num_beam_groups": 2,
        "num_return_sequences": 4,
        "diversity_penalty": 1.0,
        "max_new_tokens": 20,
        "early_stopping": True,
    }
    original = GenerationConfig(**common)
    offloaded = GenerationConfig(cache_implementation="offloaded", **common)

    # Reset the peak counter before each run so each reading covers only that run.
    torch.cuda.reset_peak_memory_stats(device)
    model.generate(generation_config=original, **inputs)
    original_peak_memory = torch.cuda.max_memory_allocated(device)

    torch.cuda.reset_peak_memory_stats(device)
    model.generate(generation_config=offloaded, **inputs)
    offloaded_peak_memory = torch.cuda.max_memory_allocated(device)

    # A unittest assertion is not stripped under `python -O` and shows both measurements
    # on failure, unlike a bare `assert`.
    self.assertLess(
        offloaded_peak_memory,
        original_peak_memory,
        msg="offloaded cache peak memory (bytes) should be below the default cache peak",
    )
|
||||
|
||||
Reference in New Issue
Block a user