[tests] Test all cache implementations (#37873)
This commit is contained in:
@@ -18,13 +18,12 @@ import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
|
||||
from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
cleanup,
|
||||
get_gpu_count,
|
||||
is_torch_available,
|
||||
require_gptq,
|
||||
require_non_xpu,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
@@ -32,6 +31,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -40,15 +40,24 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
Cache,
|
||||
ClvpForCausalLM,
|
||||
DynamicCache,
|
||||
GenerationConfig,
|
||||
LlamaConfig,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
from transformers.utils import is_torch_greater_or_equal
|
||||
|
||||
|
||||
# Cache implementations exercised by the parameterized integration tests: every entry of
# `ALL_CACHE_IMPLEMENTATIONS` except the two excluded below.
TEST_CACHE_IMPLEMENTATIONS = [
    cache_name
    for cache_name in ALL_CACHE_IMPLEMENTATIONS
    # TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
    # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
    if cache_name not in ("mamba", "offloaded_hybrid")
]
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -176,9 +185,121 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class CacheIntegrationTest(unittest.TestCase):
|
||||
"""Cache tests that require loading models"""
|
||||
"""Fast cache integration tests that share the same small model"""
|
||||
|
||||
@classmethod
def setUpClass(cls):
    """Load the tokenizer and model a single time; they are stored on the class and reused by the tests."""
    checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
    cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    cls.model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)
    # hack to enable the use of caches with sliding windows
    cls.model.config.sliding_window = 256
|
||||
|
||||
def _skip_on_uninstalled_cache_dependencies(self, cache_implementation):
    """
    Skip the running test when `cache_implementation` relies on something that is unavailable:
    the "quantized" cache needs optimum-quanto, and any "offloaded" cache needs a non-CPU `torch_device`.
    """
    needs_quanto = cache_implementation == "quantized"
    if needs_quanto and not is_optimum_quanto_available():
        self.skipTest("Quanto is not available")

    # Every other cache implementation has no extra requirement checked here
    if "offloaded" not in cache_implementation:
        return

    if torch_device is None or torch_device == "cpu":
        self.skipTest("Offloaded caches require an accelerator")
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_batched(self, cache_implementation):
    """Sanity check that each cache implementation handles a left-padded batch of two prompts."""
    self._skip_on_uninstalled_cache_dependencies(cache_implementation)

    prompts = ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"]
    expected_texts = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]

    batch = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.model.device)
    generate_kwargs = {
        "do_sample": False,
        "max_new_tokens": 10,
        "return_dict_in_generate": True,
        "cache_implementation": cache_implementation,
        "disable_compile": True,
    }
    outputs = self.model.generate(**batch, **generate_kwargs)

    # Sanity check: a cache was used
    self.assertIsInstance(outputs.past_key_values, Cache)

    # The decoded batch must match the reference continuations
    texts = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    self.assertListEqual(texts, expected_texts)
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_beam_search(self, cache_implementation):
    """
    Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
    (an output sequence contains multiple beam indices).
    """
    self._skip_on_uninstalled_cache_dependencies(cache_implementation)
    if cache_implementation == "offloaded_hybrid_chunked":
        # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
        # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
        # different from the other caches.
        self.skipTest("`offloaded_hybrid_chunked` fails this test")

    EXPECTED_GENERATION = [
        "Blue is the color of the sky, and the color of",
        "Blue is the color of the sky, and the second is",
    ]

    inputs = self.tokenizer(["Blue is"], return_tensors="pt").to(self.model.device)
    gen_out = self.model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=10,
        num_beams=2,
        num_return_sequences=2,
        cache_implementation=cache_implementation,
        disable_compile=True,
        return_dict_in_generate=True,
    )
    # Sanity check: a cache was used
    self.assertIsInstance(gen_out.past_key_values, Cache)

    # At least one of the sequences requires multiple beam indices -> `reorder_cache` had to shift things around.
    # The indices are converted to plain ints first: iterating a tensor row yields 0-d tensors, which hash by
    # identity, so a `set` built directly from the row would never deduplicate and the check would always pass.
    # Negative entries are ignored -- NOTE(review): assumes `generate` pads `beam_indices` with -1, confirm.
    distinct_beams_per_sequence = [
        {beam_idx for beam_idx in beams_in_sequence.tolist() if beam_idx >= 0}
        for beams_in_sequence in gen_out.beam_indices
    ]
    self.assertTrue(any(len(distinct_beams) > 1 for distinct_beams in distinct_beams_per_sequence))

    # Confirm that the output matches expectations
    decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_extra_left_padding(self, cache_implementation):
    """Generation with the cache must be unchanged when the same prompt carries extra left-padding."""
    self._skip_on_uninstalled_cache_dependencies(cache_implementation)

    prompt = ["The cat"]
    expected_texts = ["The cat's whiskers are also a sign of anxiety."]
    generate_kwargs = {
        "do_sample": False,
        "max_new_tokens": 10,
        "cache_implementation": cache_implementation,
        "disable_compile": True,
    }

    # Reference run: no padding beyond what `padding=True` adds
    plain_batch = self.tokenizer(prompt, padding=True, return_tensors="pt").to(self.model.device)
    plain_ids = self.model.generate(**plain_batch, **generate_kwargs)
    plain_texts = self.tokenizer.batch_decode(plain_ids, skip_special_tokens=True)
    self.assertListEqual(plain_texts, expected_texts)

    # Now with extra left-padding
    padded_batch = self.tokenizer(prompt, padding=True, return_tensors="pt", pad_to_multiple_of=32)
    padded_batch = padded_batch.to(self.model.device)
    self.assertTrue(plain_batch.input_ids.shape[1] < padded_batch.input_ids.shape[1])
    padded_ids = self.model.generate(**padded_batch, **generate_kwargs)
    padded_texts = self.tokenizer.batch_decode(padded_ids, skip_special_tokens=True)
    self.assertListEqual(padded_texts, expected_texts)
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class CacheHardIntegrationTest(unittest.TestCase):
|
||||
"""Hard cache integration tests that require loading different models"""
|
||||
|
||||
def tearDown(self):
|
||||
# Some tests use large models, which might result in suboptimal torch re-allocation if we run multiple tests
|
||||
@@ -187,18 +308,15 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_dynamic_cache_hard(self):
|
||||
"""Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
|
||||
|
||||
# DynamicCache and the legacy cache format should be equivalent
|
||||
set_seed(0)
|
||||
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
|
||||
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
|
||||
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256)
|
||||
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = (
|
||||
@@ -215,138 +333,11 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(decoded[0], expected_text)
|
||||
|
||||
@slow
def test_dynamic_cache_batched(self):
    """Greedy generation on a padded batch of two prompts with an explicit `DynamicCache`."""
    checkpoint = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    # Reuse the EOS token for padding
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)

    prompts = ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"]
    batch = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)

    generated = model.generate(**batch, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
    texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
    self.assertListEqual(
        texts,
        ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"],
    )
|
||||
|
||||
@slow
def test_dynamic_cache_beam_search(self):
    """Beam search (2 beams, 2 returned sequences) must reproduce the reference continuations."""
    checkpoint = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)

    batch = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
    beam_kwargs = {
        "do_sample": False,
        "max_new_tokens": 20,
        "num_beams": 2,
        "num_return_sequences": 2,
    }
    generated = model.generate(**batch, **beam_kwargs)
    texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

    reference = [
        "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good",
        "The best color is the one that suits you.\nThe best color is the one that suits you. The",
    ]
    self.assertListEqual(texts, reference)
|
||||
|
||||
@slow
def test_hybrid_cache_n_sequences(self):
    """Beam search returning two sequences on gemma-2-9b (eager attention) must match the reference texts."""
    checkpoint = "google/gemma-2-9b"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    load_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "eager",
    }
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)

    batch = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
    generated = model.generate(
        **batch,
        do_sample=False,
        max_new_tokens=20,
        num_return_sequences=2,
        num_beams=2,
    )
    texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

    reference = [
        "Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
        "Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",
    ]
    self.assertListEqual(texts, reference)
|
||||
|
||||
@require_non_xpu
@require_gptq
@slow
def test_sink_cache_hard(self):
    """Long greedy generation (3000 new tokens) through a `SinkCache` must end with the reference text."""
    checkpoint = "TheBloke/LLaMa-7B-GPTQ"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

    prompt = ["Vaswani et al. (2017) introduced the Transformers"]
    batch = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run
    # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of")
    sink_cache = SinkCache(window_length=508, num_sink_tokens=4)
    generated = model.generate(**batch, do_sample=False, max_new_tokens=3000, past_key_values=sink_cache)
    first_text = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    self.assertTrue(first_text.endswith("to perform a variety of tasks. The Transformer is a neural network"))
|
||||
|
||||
@slow
def test_sink_cache_iterative_prompts(self):
    """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
    checkpoint = "HuggingFaceH4/zephyr-7b-beta"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)
    prompt = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
        "and must-see attractions."
    )

    # Prepare generation settings
    cache = SinkCache(window_length=256, num_sink_tokens=4)
    # Starts empty; each round appends the templated prompt to everything generated so far
    running_ids = torch.tensor([], device=model.device, dtype=torch.int)
    for _ in range(3):
        # Tokenize the prompt with the correct chat template
        chat = [{"role": "user", "content": prompt}]
        chat_ids = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True)
        chat_ids = chat_ids.to(model.device)
        running_ids = torch.cat((running_ids, chat_ids), dim=1)

        # Perform the generation; its full output becomes the prefix of the next round
        running_ids = model.generate(
            running_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
        )

    # We went well beyond the cache length
    self.assertTrue(running_ids.shape[1] > cache.get_max_cache_shape() * 1.5)

    # And the tail of the text is still the expected coherent English
    first_text = tokenizer.batch_decode(running_ids, skip_special_tokens=True)[0]
    last_output = (
        "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
        "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
        "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
        "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
        "was visiting the historic district of Honolulu. Here,"
    )
    self.assertTrue(first_text.endswith(last_output))
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("eager", "static"),
|
||||
("sdpa", "static"),
|
||||
]
|
||||
)
|
||||
@parameterized.expand([("eager"), ("sdpa")])
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
||||
"""Tests that different cache implementations work well with eager and SDPA inference"""
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
@@ -371,124 +362,19 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = cache_implementation
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
gen_out = model.generate(
|
||||
**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static", disable_compile=True
|
||||
)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.forward = torch.compile(model.forward)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static")
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@slow
def test_dynamic_cache_extra_left_padding(self):
    """Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
    checkpoint = "NousResearch/Llama-2-7b-chat-hf"
    prompts = ["The best color is", "We should not undermind the issues at hand"]
    expected_texts = [
        "The best color is the one that complements the skin tone of the",
        "We should not undermind the issues at hand.\nWe should not undermind the issues",
    ]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left", pad_token="<s>")
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(torch_device)

    # Reference run: padded only up to the longest prompt
    plain_batch = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
    plain_ids = model.generate(**plain_batch, do_sample=False, max_new_tokens=10)
    plain_texts = tokenizer.batch_decode(plain_ids, skip_special_tokens=True)
    self.assertListEqual(plain_texts, expected_texts)

    # Now with extra left-padding
    padded_batch = tokenizer(prompts, padding=True, return_tensors="pt", pad_to_multiple_of=32).to(model.device)
    self.assertTrue(plain_batch.input_ids.shape[1] < padded_batch.input_ids.shape[1])
    padded_ids = model.generate(**padded_batch, do_sample=False, max_new_tokens=10)
    padded_texts = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)
    self.assertListEqual(padded_texts, expected_texts)
|
||||
|
||||
@slow
def test_static_cache_extra_left_padding(self):
    """Tests that adding extra left-padding does not affect the generation with the static cache"""
    checkpoint = "NousResearch/Llama-2-7b-chat-hf"
    prompts = ["The best color is", "We should not undermind the issues at hand"]
    expected_texts = [
        "The best color is the one that complements the skin tone of the",
        "We should not undermind the issues at hand.\nWe should not undermind the issues",
    ]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left", pad_token="<s>")
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(torch_device)
    plain_batch = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)

    # Select the static cache through the model's generation config, so both runs below use it
    model.generation_config.cache_implementation = "static"

    # Reference run: padded only up to the longest prompt
    plain_ids = model.generate(**plain_batch, do_sample=False, max_new_tokens=10)
    plain_texts = tokenizer.batch_decode(plain_ids, skip_special_tokens=True)
    self.assertListEqual(plain_texts, expected_texts)

    # Now with extra left-padding
    padded_batch = tokenizer(prompts, padding=True, return_tensors="pt", pad_to_multiple_of=32).to(model.device)
    self.assertTrue(plain_batch.input_ids.shape[1] < padded_batch.input_ids.shape[1])
    padded_ids = model.generate(**padded_batch, do_sample=False, max_new_tokens=10)
    padded_texts = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)
    self.assertListEqual(padded_texts, expected_texts)
|
||||
|
||||
@unittest.skip(reason="TODO @gante static cache's does not support beam search yet")
def test_static_cache_beam_search(self):
    """Placeholder with no body; always skipped for the reason given in the decorator."""
|
||||
|
||||
@require_torch_accelerator
@slow
def test_offloaded_cache_equivalent_to_dynamic_cache(self):
    """Tests that OffloadedCache produces the same result as the default DynamicCache"""
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    device = model.device

    # The device is only known once the model is placed, hence the skip check after loading
    if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
        self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")

    input_text = "Fun fact:"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    # Generation settings shared by both runs; only `cache_implementation` differs
    common = {
        "num_beams": 4,
        "num_beam_groups": 2,
        "num_return_sequences": 4,
        "diversity_penalty": 1.0,
        "max_new_tokens": 20,
        "early_stopping": True,
    }
    original = GenerationConfig(**common)
    offloaded = GenerationConfig(cache_implementation="offloaded", **common)
    original_outputs = model.generate(generation_config=original, **inputs)
    offloaded_outputs = model.generate(generation_config=offloaded, **inputs)

    # Compare with a unittest assertion rather than a bare `assert` (which is stripped under `python -O`).
    # Comparing the whole outputs at once also catches a differing number of returned sequences, which a
    # `zip` over the rows would silently truncate.
    self.assertListEqual(original_outputs.tolist(), offloaded_outputs.tolist())
|
||||
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
|
||||
@@ -526,12 +412,14 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
torch_accelerator_module.reset_peak_memory_stats(device)
|
||||
model.generate(generation_config=offloaded, **inputs)
|
||||
offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
|
||||
print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
|
||||
assert offloaded_peak_memory < original_peak_memory
|
||||
self.assertTrue(offloaded_peak_memory < original_peak_memory)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_cache_copy(self):
|
||||
"""Tests that we can manually set a cache, copy, and reuse it for generation"""
|
||||
# TODO (joao): test for all cache implementations in `CacheIntegrationTest` after standardizing the
|
||||
# lazy init of cache layers
|
||||
model_name = "microsoft/Phi-3-mini-4k-instruct"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
|
||||
@@ -542,7 +430,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
|
||||
INITIAL_PROMPT = "You are a helpful assistant. "
|
||||
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
|
||||
# This is the common prompt cached, we need to run forward without grad to be abel to copy
|
||||
# This is the common prompt cached, we need to run forward without grad to be able to copy
|
||||
with torch.no_grad():
|
||||
prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values
|
||||
|
||||
@@ -551,14 +439,19 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
for prompt in prompts:
|
||||
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
|
||||
past_key_values = copy.deepcopy(prompt_cache)
|
||||
outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=40)
|
||||
outputs = model.generate(
|
||||
**new_inputs, past_key_values=past_key_values, max_new_tokens=40, disable_compile=True
|
||||
)
|
||||
response = tokenizer.batch_decode(outputs)[0]
|
||||
responses.append(response)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
|
||||
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
|
||||
] # fmt: skip
|
||||
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a wonderful "
|
||||
"way to explore new places, cultures, and experiences. Whether you are a seasoned traveler or a "
|
||||
"first-time adventurer, there is always something",
|
||||
"You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
|
||||
"of France.\n\n\n\n\n\n\n<|endoftext|>",
|
||||
]
|
||||
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@@ -609,7 +502,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
|
||||
with CaptureStderr() as cap:
|
||||
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
|
||||
self.assertEqual(cap.err, "")
|
||||
self.assertNotIn("cuda", cap.err.lower())
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user