Refactor MambaCache to modeling_mamba.py (#38086)

* Refactor MambaCache to modeling_mamba.py (parity with Zamba)

* ruff

* fix dummies

* update

* update

* remove mamba ref in cache tests

* remove cache_implementation from tests

* update

* ruff

* ruff

* sneaky regression

* model consistency

* fix test_multi_gpu_data_parallel_forward

* fix falcon slow tests

* ruff

* ruff

* add sample false

* try to fix slow tests

* Revert "fix test_multi_gpu_data_parallel_forward"

This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33.

* fix tests on nvidia t4, remove dataparallel tests from mamba

* ruff

* remove DDP tests from mamba and falcon_mamba

* add explicit error for MambaCache

* mamba2 also needs to init cache in prepare_inputs_for_generation

* ruff

* ruff

* move MambaCache to its own file

* ruff

* unprotected import fix

* another attempt to fix unprotected imports

* Revert "another attempt to fix unprotected imports"

This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2.

* fixing unprotected import, attempt 3

* Update src/transformers/cache_utils.py

* ruff's fault

* fix arthur review

* modular falcon mamba

* found a hack

* fix config docs

* fix docs

* add export info

* merge modular falcon branch

* oopsie

* fix fast path failing

* new approach

* oopsie

* fix types

* Revert new pragma in modular

This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6.

* trying another modular workaround

* review & fix ci

* oopsie

* clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
Manuel de Prada Corral
2025-07-21 14:59:36 +02:00
committed by GitHub
parent a419a40234
commit 1aa7256f01
16 changed files with 1033 additions and 307 deletions

View File

@@ -110,6 +110,13 @@ outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## FalconMambaCache
[[autodoc]] FalconMambaCache
- update_conv_state
- update_ssm_state
- reset
## FalconMambaConfig
[[autodoc]] FalconMambaConfig

View File

@@ -116,6 +116,13 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
trainer.train()
```
## MambaCache
[[autodoc]] MambaCache
- update_conv_state
- update_ssm_state
- reset
## MambaConfig
[[autodoc]] MambaConfig