Refactor MambaCache to modeling_mamba.py (#38086)

* Refactor MambaCache to modeling_mamba.py (parity with Zamba)

* ruff

* fix dummies

* update

* update

* remove mamba ref in cache tests

* remove cache_implementation from tests

* update

* ruff

* ruff

* sneaky regression

* model consistency

* fix test_multi_gpu_data_parallel_forward

* fix falcon slow tests

* ruff

* ruff

* add sample false

* try to fix slow tests

* Revert "fix test_multi_gpu_data_parallel_forward"

This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33.

* fix tests on nvidia t4, remove dataparallel tests from mamba

* ruff

* remove DDP tests from mamba and falcon_mamba

* add explicit error for MambaCache

* mamba2 also needs to init cache in prepare_inputs_for_generation

* ruff

* ruff

* move MambaCache to its own file

* ruff

* unprotected import fix

* another attempt to fix unprotected imports

* Revert "another attempt to fix unprotected imports"

This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2.

* fixing unprotected import, attempt 3

* Update src/transformers/cache_utils.py

* ruff's fault

* fix arthur review

* modular falcon mamba

* found a hack

* fix config docs

* fix docs

* add export info

* merge modular falcon branch

* oopsie

* fix fast path failing

* new approach

* oopsie

* fix types

* Revert new pragma in modular

This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6.

* trying another modular workaround

* review & fix ci

* oopsie

* clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
Manuel de Prada Corral
2025-07-21 14:59:36 +02:00
committed by GitHub
parent a419a40234
commit 1aa7256f01
16 changed files with 1033 additions and 307 deletions

View File

@@ -20,7 +20,7 @@ from unittest.util import safe_repr
from parameterized import parameterized
from transformers import AutoTokenizer, MambaConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -32,10 +32,10 @@ if is_torch_available():
import torch
from transformers import (
MambaCache,
MambaForCausalLM,
MambaModel,
)
from transformers.models.mamba.modeling_mamba import MambaCache
class MambaModelTester:
@@ -279,31 +279,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_config(self):
self.config_tester.run_common_tests()
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["cache_params"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)
# move input tensors to cuda:O
for k, v in inputs_dict.items():
if torch.is_tensor(v):
inputs_dict[k] = v.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
def test_mamba_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mamba_model(*config_and_inputs)
@@ -452,6 +427,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
)
@unittest.skip("Mamba models do not support DDP.")
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
class MambaIntegrationTests(unittest.TestCase):
@@ -547,11 +526,11 @@ class MambaIntegrationTests(unittest.TestCase):
torch_device
)
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)