Refactor MambaCache to modeling_mamba.py (#38086)
* Refactor MambaCache to modeling_mamba.py (parity with Zamba)
* ruff
* fix dummies
* update
* update
* remove mamba ref in cache tests
* remove cache_implementation from tests
* update
* ruff
* ruff
* sneaky regression
* model consistency
* fix test_multi_gpu_data_parallel_forward
* fix falcon slow tests
* ruff
* ruff
* add sample false
* try to fix slow tests
* Revert "fix test_multi_gpu_data_parallel_forward"
  This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33.
* fix tests on nvidia t4, remove dataparallel tests from mamba
* ruff
* remove DDP tests from mamba and falcon_mamba
* add explicit error for MambaCache
* mamba2 also needs to init cache in prepare_inputs_for_generation
* ruff
* ruff
* move MambaCache to its own file
* ruff
* unprotected import fix
* another attempt to fix unprotected imports
* Revert "another attempt to fix unprotected imports"
  This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2.
* fixing unprotected import, attempt 3
* Update src/transformers/cache_utils.py
* ruff's fault
* fix arthur review
* modular falcon mamba
* found a hack
* fix config docs
* fix docs
* add export info
* merge modular falcon branch
* oopsie
* fix fast path failing
* new approach
* oopsie
* fix types
* Revert new pragma in modular
  This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6.
* trying another modular workaround
* review & fix ci
* oopsie
* clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
committed by
GitHub
parent
a419a40234
commit
1aa7256f01
@@ -20,7 +20,7 @@ from unittest.util import safe_repr
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, MambaConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -32,10 +32,10 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MambaCache,
|
||||
MambaForCausalLM,
|
||||
MambaModel,
|
||||
)
|
||||
from transformers.models.mamba.modeling_mamba import MambaCache
|
||||
|
||||
|
||||
class MambaModelTester:
|
||||
@@ -279,31 +279,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["cache_params"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
# move input tensors to cuda:0
|
||||
for k, v in inputs_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs_dict[k] = v.to(0)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model.to(0)
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_mamba_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mamba_model(*config_and_inputs)
|
||||
@@ -452,6 +427,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
|
||||
)
|
||||
|
||||
@unittest.skip("Mamba models do not support DDP.")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class MambaIntegrationTests(unittest.TestCase):
|
||||
@@ -547,11 +526,11 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
|
||||
output = model.generate(input_ids, max_new_tokens=20)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
self.assertEqual(output_sentence, expected_output)
|
||||
|
||||
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
|
||||
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
|
||||
output = model.generate(input_ids, max_new_tokens=20)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
self.assertEqual(output_sentence, expected_output)
|
||||
|
||||
Reference in New Issue
Block a user