Add torch.compile Support For Mamba (#31247)

* modify mamba cache

* set up cache

* add test

* [run-slow] mamba

* [run-slow] mamba

* address comments

* [run-slow] mamba

* use_cache_position

* [run-slow] mamba

* [run-slow] mamba

* [run-slow] mamba

* [run-slow] mamba

* fix

* cache in generate

* [run-slow] mamba

* address comments

* [run-slow] mamba

* [run-slow] mamba

* address comments

* [run-slow] mamba

* fix

* [run-slow] mamba

* fix

* [run-slow] mamba

* fix cache name

* [run-slow] mamba
This commit is contained in:
Longjie Zheng
2024-07-18 11:54:54 -04:00
committed by GitHub
parent 4c040aba02
commit c75969ee28
4 changed files with 225 additions and 85 deletions

View File

@@ -187,11 +187,20 @@ class MambaModelTester:
outputs = model(input_ids)
output_whole = outputs.last_hidden_state
outputs = model(input_ids[:, :-1], use_cache=True)
outputs = model(
input_ids[:, :-1],
use_cache=True,
cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
)
output_one = outputs.last_hidden_state
# Using the state computed on the first inputs, we will get the same output
outputs = model(input_ids[:, -1:], cache_params=outputs.cache_params)
outputs = model(
input_ids[:, -1:],
use_cache=True,
cache_params=outputs.cache_params,
cache_position=torch.arange(config.conv_kernel, config.conv_kernel + 1, device=input_ids.device),
)
output_two = outputs.last_hidden_state
self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
@@ -207,11 +216,13 @@ class MambaModelTester:
# create cache
cache = model(input_ids, use_cache=True).cache_params
cache.seqlen_offset = 0
cache.reset()
# use cache
token_emb = model.embeddings(input_ids)
outputs = model.layers[0].mixer.slow_forward(token_emb, cache)
outputs = model.layers[0].mixer.slow_forward(
token_emb, cache, cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device)
)
loss = torch.log(1 + torch.abs(outputs.sum()))
self.parent.assertEqual(loss.shape, ())
@@ -508,3 +519,21 @@ class MambaIntegrationTests(unittest.TestCase):
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
@slow
def test_compile_mamba_cache(self):
    """Generate with the "mamba" cache eagerly, then with a compiled forward.

    Both runs decode 20 new tokens from the same prompt and must produce
    the same reference sentence.
    """
    prompt_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
    expected_output = "Hello my name is John and I am a\n\nI am a single father of a beautiful daughter. I am a"

    model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16)
    model = model.to(torch_device)

    # Phase 1: eager forward pass.
    eager_tokens = model.generate(prompt_ids, max_new_tokens=20, cache_implementation="mamba")
    eager_sentence = self.tokenizer.decode(eager_tokens[0].tolist())
    self.assertEqual(eager_sentence, expected_output)

    # Phase 2: replace forward with its compiled counterpart; fullgraph=True
    # makes compilation fail on any graph break rather than fall back silently.
    model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
    compiled_tokens = model.generate(prompt_ids, max_new_tokens=20, cache_implementation="mamba")
    compiled_sentence = self.tokenizer.decode(compiled_tokens[0].tolist())
    self.assertEqual(compiled_sentence, expected_output)