Refactor MambaCache to modeling_mamba.py (#38086)

* Refactor MambaCache to modeling_mamba.py (parity with Zamba)

* ruff

* fix dummies

* update

* update

* remove mamba ref in cache tests

* remove cache_implementation from tests

* update

* ruff

* ruff

* sneaky regression

* model consistency

* fix test_multi_gpu_data_parallel_forward

* fix falcon slow tests

* ruff

* ruff

* add sample false

* try to fix slow tests

* Revert "fix test_multi_gpu_data_parallel_forward"

This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33.

* fix tests on nvidia t4, remove dataparallel tests from mamba

* ruff

* remove DDP tests from mamba and falcon_mamba

* add explicit error for MambaCache

* mamba2 also needs to init cache in prepare_inputs_for_generation

* ruff

* ruff

* move MambaCache to its own file

* ruff

* unprotected import fix

* another attempt to fix unprotected imports

* Revert "another attempt to fix unprotected imports"

This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2.

* fixing unprotected import, attempt 3

* Update src/transformers/cache_utils.py

* ruff's fault

* fix arthur review

* modular falcon mamba

* found a hack

* fix config docs

* fix docs

* add export info

* merge modular falcon branch

* oopsie

* fix fast path failing

* new approach

* oopsie

* fix types

* Revert new pragma in modular

This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6.

* trying another modular workaround

* review & fix ci

* oopsie

* clear prepare_inputs on mamba/mamba2/falcon_mamba
Manuel de Prada Corral
2025-07-21 14:59:36 +02:00
committed by GitHub
parent a419a40234
commit 1aa7256f01
16 changed files with 1033 additions and 307 deletions


@@ -141,7 +141,12 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
         return updated_node
 
     def leave_ImportFrom(self, original_node, updated_node):
-        """The imports from other file types (configuration, processing etc) should use original model name."""
+        """
+        The imports from other file types (configuration, processing etc) should use original model name.
+        Also, no replaces on absolute imports (e.g. `from mamba_ssm import ...`)
+        """
+        if len(original_node.relative) == 0:  # no replaces on absolute imports
+            return original_node
         if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
             patterns = "|".join(ALL_FILE_TYPES)
             regex = rf"({patterns})_{self.new_name}"
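
The guard added in this hunk returns early for absolute imports, so a third-party package such as `mamba_ssm` is never touched by model-name replacement. The sketch below illustrates the same idea with the standard-library `ast` module instead of libcst. It is an analogy, not the converter's code: `ast.ImportFrom.level` (the number of leading dots) stands in for libcst's `relative` sequence, and `rename_relative_imports`, `some_kernel`, and the plain substring rename are hypothetical, chosen only for illustration.

```python
import ast


def is_absolute_import(node: ast.ImportFrom) -> bool:
    # `level` counts the leading dots of the import; 0 means an absolute import.
    return node.level == 0


def rename_relative_imports(source: str, old: str, new: str) -> str:
    """Toy rename: replace `old` with `new` in relative import module names only."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if not isinstance(node, ast.ImportFrom):
            continue
        if is_absolute_import(node):
            # Same early exit as the hunk: leave absolute imports untouched.
            continue
        if node.module:  # `from . import x` has module=None
            node.module = node.module.replace(old, new)
    return ast.unparse(tree)  # requires Python 3.9+


SOURCE = (
    "from mamba_ssm import some_kernel\n"  # hypothetical name from an external package
    "from .configuration_mamba import MambaConfig\n"
)
RESULT = rename_relative_imports(SOURCE, "mamba", "falcon_mamba")
print(RESULT)
```

Without the early exit, the substring rename would also rewrite `mamba_ssm` to `falcon_mamba_ssm`, which is not an importable package. That is the kind of breakage the `len(original_node.relative) == 0` check avoids.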