Refactor MambaCache to modeling_mamba.py (#38086)
* Refactor MambaCache to modeling_mamba.py (parity with Zamba)
* ruff
* fix dummies
* update
* update
* remove mamba ref in cache tests
* remove cache_implementation from tests
* update
* ruff
* ruff
* sneaky regression
* model consistency
* fix test_multi_gpu_data_parallel_forward
* fix falcon slow tests
* ruff
* ruff
* add sample false
* try to fix slow tests
* Revert "fix test_multi_gpu_data_parallel_forward"
  This reverts commit 66b7162c7c5c5ce8a73ccf48cffc8a96343ebb33.
* fix tests on nvidia t4, remove dataparallel tests from mamba
* ruff
* remove DDP tests from mamba and falcon_mamba
* add explicit error for MambaCache
* mamba2 also needs to init cache in prepare_inputs_for_generation
* ruff
* ruff
* move MambaCache to its own file
* ruff
* unprotected import fix
* another attempt to fix unprotected imports
* Revert "another attempt to fix unprotected imports"
  This reverts commit 2338354fcab630de5899321f5daced5fb312c2a2.
* fixing unprotected import, attempt 3
* Update src/transformers/cache_utils.py
* ruff's fault
* fix arthur review
* modular falcon mamba
* found a hack
* fix config docs
* fix docs
* add export info
* merge modular falcon branch
* oopsie
* fix fast path failing
* new approach
* oopsie
* fix types
* Revert new pragma in modular
  This reverts commit 80b1cf160ee251536f07c40b8a0857d499e70db6.
* trying another modular workaround
* review & fix ci
* oopsie
* clear prepare_inputs on mamba/mamba2/falcon_mamba
This commit is contained in:
committed by
GitHub
parent
a419a40234
commit
1aa7256f01
@@ -26,7 +26,6 @@ from transformers.testing_utils import (
|
||||
require_torch_accelerator,
|
||||
require_torch_large_accelerator,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -41,10 +40,10 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
FalconMambaCache,
|
||||
FalconMambaForCausalLM,
|
||||
FalconMambaModel,
|
||||
)
|
||||
from transformers.cache_utils import MambaCache
|
||||
|
||||
|
||||
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
|
||||
@@ -312,31 +311,6 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["cache_params"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
# move input tensors to cuda:O
|
||||
for k, v in inputs_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs_dict[k] = v.to(0)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model.to(0)
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_falcon_mamba_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_falcon_mamba_model(*config_and_inputs)
|
||||
@@ -411,7 +385,7 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, MambaCache): # MODIFIED PART START
|
||||
if isinstance(tuple_object, FalconMambaCache): # MODIFIED PART START
|
||||
recursive_check(tuple_object.conv_states, dict_object.conv_states)
|
||||
recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
|
||||
elif isinstance(tuple_object, (list, tuple)): # MODIFIED PART END
|
||||
@@ -458,6 +432,10 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
@unittest.skip("Mamba models do not support DDP.")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@@ -497,7 +475,9 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
def test_generation_4bit(self):
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config)
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
@@ -513,6 +493,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
|
||||
inputs = self.tokenizer(self.text, return_tensors="pt").to(torch_device)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
print(self.tokenizer.batch_decode(out, skip_special_tokens=False)[0])
|
||||
|
||||
self.assertEqual(
|
||||
self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
|
||||
@@ -543,7 +524,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.float16)
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=20)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(out, EXPECTED_OUTPUT)
|
||||
@@ -553,7 +534,7 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
||||
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
|
||||
|
||||
inputs["inputs_embeds"] = inputs_embeds
|
||||
out = model.generate(**inputs, max_new_tokens=20)
|
||||
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUTS = Expectations(
|
||||
|
||||
@@ -20,7 +20,7 @@ from unittest.util import safe_repr
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, MambaConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -32,10 +32,10 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MambaCache,
|
||||
MambaForCausalLM,
|
||||
MambaModel,
|
||||
)
|
||||
from transformers.models.mamba.modeling_mamba import MambaCache
|
||||
|
||||
|
||||
class MambaModelTester:
|
||||
@@ -279,31 +279,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["cache_params"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
# move input tensors to cuda:O
|
||||
for k, v in inputs_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs_dict[k] = v.to(0)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model.to(0)
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_mamba_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mamba_model(*config_and_inputs)
|
||||
@@ -452,6 +427,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
|
||||
)
|
||||
|
||||
@unittest.skip("Mamba models do not support DDP.")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class MambaIntegrationTests(unittest.TestCase):
|
||||
@@ -547,11 +526,11 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
|
||||
output = model.generate(input_ids, max_new_tokens=20)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
self.assertEqual(output_sentence, expected_output)
|
||||
|
||||
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
|
||||
output = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
|
||||
output = model.generate(input_ids, max_new_tokens=20)
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
self.assertEqual(output_sentence, expected_output)
|
||||
|
||||
@@ -62,8 +62,6 @@ if is_torch_available():
|
||||
TEST_CACHE_IMPLEMENTATIONS = [
|
||||
cache_name
|
||||
for cache_name in ALL_CACHE_IMPLEMENTATIONS
|
||||
# TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
|
||||
if cache_name != "mamba"
|
||||
# TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
|
||||
if cache_name != "offloaded_hybrid"
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user