Zamba new attention standard (#35375)

* updated zamba to new attention standard

* make fixup fixes
This commit is contained in:
pglorio
2025-01-06 23:08:45 -10:00
committed by GitHub
parent 12ba96aa3c
commit bd442c6d3a
2 changed files with 104 additions and 292 deletions

View File

@@ -46,7 +46,7 @@ if is_torch_available():
ZambaModel,
)
from transformers.models.zamba.modeling_zamba import (
HybridMambaAttentionDynamicCache,
ZambaHybridDynamicCache,
)
@@ -215,9 +215,7 @@ class ZambaModelTester:
# first forward pass
# Attention: Zamba needs the cache to be initialized to return a cache!
past_key_values = HybridMambaAttentionDynamicCache(
config, input_ids.shape[0], model.dtype, device=model.device
)
past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
outputs = model(
input_ids,
attention_mask=input_mask,