Fixed Hybrid Cache Shape Initialization. (#32163)
* fixed hybrid cache init, added test * Fix Test Typo --------- Co-authored-by: Aaron Haag <aaron.haag@siemens.com>
This commit is contained in:
@@ -292,6 +292,30 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
def test_hybrid_cache_n_sequences(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2-9b",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
|
||||
inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
|
||||
|
||||
gen_out = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=20,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = [
|
||||
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
@require_auto_gptq
|
||||
def test_sink_cache_hard(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
|
||||
|
||||
Reference in New Issue
Block a user