Generate: SinkCache can handle iterative prompts (#27907)

This commit is contained in:
Joao Gante
2023-12-08 20:02:20 +00:00
committed by GitHub
parent 94c765380c
commit ce0bbd5101
6 changed files with 116 additions and 34 deletions

View File

@@ -187,3 +187,45 @@ class CacheIntegrationTest(unittest.TestCase):
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
def test_sink_cache_iterative_prompts(self):
    """Tests that SinkCache supports more than one new token at once, when shifting the cache.

    The same chat-templated prompt is appended to the running sequence three times. Each round generates up to
    100 new tokens with greedy decoding while reusing a single `SinkCache(window_length=256, num_sink_tokens=4)`.
    The test then checks that the final sequence is more than 1.5x `cache.get_max_length()` tokens long and that
    the decoded text still ends with the expected completion.
    """
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
    )
    prompt = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
        "and must-see attractions."
    )

    # Tokenize the prompt with the correct chat template. The prompt is identical on every round, so this is
    # done once, outside the loop.
    chat = [{"role": "user", "content": prompt}]
    tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
        model.device
    )

    # Prepare generation settings
    cache = SinkCache(window_length=256, num_sink_tokens=4)
    # Start from an empty sequence, so that the first concatenation below yields just the tokenized prompt
    input_ids = torch.tensor([], device=model.device, dtype=torch.int)
    for _ in range(3):
        # Append a fresh copy of the prompt to everything generated so far
        input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
        # Perform the generation -- its output becomes the history for the next round
        input_ids = model.generate(
            input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
        )

    # We went well beyond the cache length (`assertGreater` reports both values on failure)
    self.assertGreater(input_ids.shape[1], cache.get_max_length() * 1.5)

    # And it still produces coherent English
    decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    last_output = (
        "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
        "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
        "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
        "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
        "was visiting the historic district of Honolulu. Here,"
    )
    self.assertTrue(
        decoded[0].endswith(last_output),
        msg=f"Unexpected end of generation: {decoded[0][-len(last_output):]!r}",
    )