Zamba new attention standard (#35375)
* updated zamba to new attention standard
* make fixup fixes
This commit is contained in:
@@ -46,7 +46,7 @@ if is_torch_available():
|
||||
ZambaModel,
|
||||
)
|
||||
from transformers.models.zamba.modeling_zamba import (
|
||||
HybridMambaAttentionDynamicCache,
|
||||
ZambaHybridDynamicCache,
|
||||
)
|
||||
|
||||
|
||||
@@ -215,9 +215,7 @@ class ZambaModelTester:
|
||||
|
||||
# first forward pass
|
||||
# Attention: Zamba needs the cache to be initialized to return a cache!
|
||||
past_key_values = HybridMambaAttentionDynamicCache(
|
||||
config, input_ids.shape[0], model.dtype, device=model.device
|
||||
)
|
||||
past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
|
||||
outputs = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
|
||||
Reference in New Issue
Block a user