Generate: SinkCache can handle iterative prompts (#27907)
This commit is contained in:
@@ -187,3 +187,45 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
|
||||
|
||||
def test_sink_cache_iterative_prompts(self):
    """Tests that SinkCache supports more than one new token at once, when shifting the cache

    The same chat-formatted prompt is appended to the running sequence three times, and
    `generate` is called after each append with the same `SinkCache` instance. From the second
    round onwards, each call therefore feeds a whole new prompt (many tokens) on top of an
    already populated cache. The test then checks that the full sequence is well beyond the
    cache length and that the greedy continuation still ends with the expected text.
    """
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
    )
    prompt = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
        "and must-see attractions."
    )

    # Tokenize the prompt with the correct chat template. The prompt is identical in every round,
    # so this is done once, outside the loop.
    chat = [{"role": "user", "content": prompt}]
    tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
        model.device
    )

    # Prepare generation settings
    cache = SinkCache(window_length=256, num_sink_tokens=4)
    # Start from a 1-D empty tensor: `torch.cat` accepts a size-(0,) tensor alongside the 2-D
    # prompt tensor, so the first round needs no special-casing.
    input_ids = torch.tensor([], device=model.device, dtype=torch.int)

    for _ in range(3):
        # Append a fresh copy of the prompt to everything produced so far
        input_ids = torch.cat((input_ids, tokenized_chat), dim=1)

        # Perform the generation; the output (prompt + new tokens) seeds the next round
        input_ids = model.generate(
            input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
        )

    # We went well beyond the cache length
    self.assertGreater(input_ids.shape[1], cache.get_max_length() * 1.5)

    # And it still produces a coherent english
    decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    last_output = (
        "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
        "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
        "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
        "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
        "was visiting the historic district of Honolulu. Here,"
    )
    self.assertTrue(decoded[0].endswith(last_output))
|
||||
|
||||
Reference in New Issue
Block a user