Make mamba use cache (#31116)
* make mamba use cache * uss cache naming as in mamba * fix musicgen
This commit is contained in:
committed by
GitHub
parent
f5c0fa9f6f
commit
7729b77478
@@ -447,10 +447,9 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16)
|
||||
model.to(device)
|
||||
model.config.use_cache = True
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
|
||||
|
||||
out = model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=10)
|
||||
output_sentence = tokenizer.decode(out[0, :])
|
||||
self.assertEqual(output_sentence, "Hey how are you doing?\n\nI'm so glad you're here.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user