Refactor MambaCache to modeling_mamba.py (#38086)

* Refactor MambaCache to modeling_mamba.py (parity with Zamba)

* ruff

* fix dummies

* update

* update

* remove mamba ref in cache tests

* remove cache_implementation from tests

* update

* ruff

* ruff

* sneaky regression

* model consistency

* fix test_multi_gpu_data_parallel_forward

* fix falcon slow tests

* ruff

* ruff

* add sample false

* try to fix slow tests

* Revert "fix test_multi_gpu_data_parallel_forward"

This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33.

* fix tests on nvidia t4, remove dataparallel tests from mamba

* ruff

* remove DDP tests from mamba and falcon_mamba

* add explicit error for MambaCache

* mamba2 also needs to init cache in prepare_inputs_for_generation

* ruff

* ruff

* move MambaCache to its own file

* ruff

* unprotected import fix

* another attempt to fix unprotected imports

* Revert "another attempt to fix unprotected imports"

This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2.

* fixing unprotected import, attempt 3

* Update src/transformers/cache_utils.py

* ruff's fault

* fix arthur review

* modular falcon mamba

* found a hack

* fix config docs

* fix docs

* add export info

* merge modular falcon branch

* oopsie

* fix fast path failing

* new approach

* oopsie

* fix types

* Revert new pragma in modular

This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6.

* trying another modular workaround

* review & fix ci

* oopsie

* clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
Manuel de Prada Corral
2025-07-21 14:59:36 +02:00
committed by GitHub
parent a419a40234
commit 1aa7256f01
16 changed files with 1033 additions and 307 deletions

View File

@@ -26,7 +26,6 @@ from transformers.testing_utils import (
require_torch_accelerator,
require_torch_large_accelerator,
require_torch_multi_accelerator,
require_torch_multi_gpu,
slow,
torch_device,
)
@@ -41,10 +40,10 @@ if is_torch_available():
import torch
from transformers import (
FalconMambaCache,
FalconMambaForCausalLM,
FalconMambaModel,
)
from transformers.cache_utils import MambaCache
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
@@ -312,31 +311,6 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
def test_config(self):
self.config_tester.run_common_tests()
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["cache_params"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)
# move input tensors to cuda:0
for k, v in inputs_dict.items():
if torch.is_tensor(v):
inputs_dict[k] = v.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
def test_falcon_mamba_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_falcon_mamba_model(*config_and_inputs)
@@ -411,7 +385,7 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, MambaCache): # MODIFIED PART START
if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START
recursive_check(tuple_object.conv_states, dict_object.conv_states)
recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END
@@ -458,6 +432,10 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@unittest.skip("Mamba models do not support DDP.")
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
@require_torch_accelerator
@@ -497,7 +475,9 @@ class FalconMambaIntegrationTests(unittest.TestCase):
@require_bitsandbytes
def test_generation_4bit(self):
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config)
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config).to(
torch_device
)
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -513,6 +493,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(self.tokenizer.batch_decode(out, skip_special_tokens=False)[0])
self.assertEqual(
self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
@@ -543,7 +524,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.float16)
out = model.generate(**inputs, max_new_tokens=20)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
out = tok.batch_decode(out, skip_special_tokens=True)
self.assertListEqual(out, EXPECTED_OUTPUT)
@@ -553,7 +534,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
inputs["inputs_embeds"] = inputs_embeds
out = model.generate(**inputs, max_new_tokens=20)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
out = tok.batch_decode(out, skip_special_tokens=True)
EXPECTED_OUTPUTS = Expectations(

View File

@@ -20,7 +20,7 @@ from unittest.util import safe_repr
from parameterized import parameterized
from transformers import AutoTokenizer, MambaConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -32,10 +32,10 @@ if is_torch_available():
import torch
from transformers import (
MambaCache,
MambaForCausalLM,
MambaModel,
)
from transformers.models.mamba.modeling_mamba import MambaCache
class MambaModelTester:
@@ -279,31 +279,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_config(self):
self.config_tester.run_common_tests()
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["cache_params"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)
# move input tensors to cuda:0
for k, v in inputs_dict.items():
if torch.is_tensor(v):
inputs_dict[k] = v.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
def test_mamba_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mamba_model(*config_and_inputs)
@@ -452,6 +427,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
)
@unittest.skip("Mamba models do not support DDP.")
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
class MambaIntegrationTests(unittest.TestCase):
@@ -547,11 +526,11 @@ class MambaIntegrationTests(unittest.TestCase):
torch_device
)
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)

View File

@@ -62,8 +62,6 @@ if is_torch_available():
TEST_CACHE_IMPLEMENTATIONS = [
cache_name
for cache_name in ALL_CACHE_IMPLEMENTATIONS
# TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
if cache_name != "mamba"
# TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
if cache_name != "offloaded_hybrid"
]