Add torch.compile Support For Mamba (#31247)
* modify mamba cache
* set up cache
* add test
* [run-slow] mamba
* [run-slow] mamba
* address comments
* [run-slow] mamba
* use_cache_position
* [run-slow] mamba
* [run-slow] mamba
* [run-slow] mamba
* [run-slow] mamba
* fix
* cache in generate
* [run-slow] mamba
* address comments
* [run-slow] mamba
* [run-slow] mamba
* address comments
* [run-slow] mamba
* fix
* [run-slow] mamba
* fix
* [run-slow] mamba
* fix cache name
* [run-slow] mamba
This commit is contained in:
@@ -187,11 +187,20 @@ class MambaModelTester:
|
||||
outputs = model(input_ids)
|
||||
output_whole = outputs.last_hidden_state
|
||||
|
||||
outputs = model(input_ids[:, :-1], use_cache=True)
|
||||
outputs = model(
|
||||
input_ids[:, :-1],
|
||||
use_cache=True,
|
||||
cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
|
||||
)
|
||||
output_one = outputs.last_hidden_state
|
||||
|
||||
# Using the state computed on the first inputs, we will get the same output
|
||||
outputs = model(input_ids[:, -1:], cache_params=outputs.cache_params)
|
||||
outputs = model(
|
||||
input_ids[:, -1:],
|
||||
use_cache=True,
|
||||
cache_params=outputs.cache_params,
|
||||
cache_position=torch.arange(config.conv_kernel, config.conv_kernel + 1, device=input_ids.device),
|
||||
)
|
||||
output_two = outputs.last_hidden_state
|
||||
|
||||
self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
|
||||
@@ -207,11 +216,13 @@ class MambaModelTester:
|
||||
|
||||
# create cache
|
||||
cache = model(input_ids, use_cache=True).cache_params
|
||||
cache.seqlen_offset = 0
|
||||
cache.reset()
|
||||
|
||||
# use cache
|
||||
token_emb = model.embeddings(input_ids)
|
||||
outputs = model.layers[0].mixer.slow_forward(token_emb, cache)
|
||||
outputs = model.layers[0].mixer.slow_forward(
|
||||
token_emb, cache, cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device)
|
||||
)
|
||||
|
||||
loss = torch.log(1 + torch.abs(outputs.sum()))
|
||||
self.parent.assertEqual(loss.shape, ())
|
||||
@@ -508,3 +519,21 @@ class MambaIntegrationTests(unittest.TestCase):
|
||||
output_sentence = self.tokenizer.decode(output[0].tolist())
|
||||
|
||||
self.assertEqual(output_sentence, expected_output)
|
||||
|
||||
@slow
def test_compile_mamba_cache(self):
    """Generation with the "mamba" cache must give the same text before and after torch.compile.

    Loads the "state-spaces/mamba-1.4b-hf" checkpoint in float16, generates 20 new
    tokens from a fixed prompt, and compares the decoded sentence with a reference
    string. ``model.forward`` is then replaced by its ``torch.compile`` wrapper
    (``fullgraph=True``, ``mode="reduce-overhead"``) and the same check is repeated.
    """
    expected_output = "Hello my name is John and I am a\n\nI am a single father of a beautiful daughter. I am a"

    encoded_prompt = self.tokenizer("Hello my name is", return_tensors="pt")
    input_ids = encoded_prompt.input_ids.to(torch_device)

    model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16)
    model = model.to(torch_device)

    def generate_sentence():
        # Looks up `model.forward` at call time, so the compiled forward is used
        # once it has been swapped in below.
        generated_ids = model.generate(input_ids, max_new_tokens=20, cache_implementation="mamba")
        return self.tokenizer.decode(generated_ids[0].tolist())

    # Reference run with the uncompiled forward pass.
    self.assertEqual(generate_sentence(), expected_output)

    # Same generation again, now going through the compiled forward pass.
    model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
    self.assertEqual(generate_sentence(), expected_output)
|
||||
|
||||
Reference in New Issue
Block a user