Fix Mamba slow path bug with dtype mismatch. (#32691)
* Fix Mamba slow path bug with dtype mismatch.
* Update test_modeling_mamba.py
* Improve style.
* Fix issue with cache position of dtype mismatch test.
* Change test for slow path.
* Revert changes.
* Switch to buggy code and add test to catch it.
* Fix the dtype mismatch bug and add test code to verify it.
* Fix minor bug with test.
* Fix incorrect dtype of model output.
* Fix incorrect dtype of cache.
* Fix incorrect dtype of ssm cache.
* Fix incorrect dtype of conv state.
* Remove assertion for ssm state.
* Add assertion for conv state dtype.
* Fix all issues with dtype mismatch test.
This commit is contained in:
committed by
GitHub
parent
570c89625b
commit
c269c5c74d
@@ -1797,7 +1797,7 @@ class MambaCache:
|
|||||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||||
|
|
||||||
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
||||||
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
|
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
|
||||||
self.conv_states[layer_idx].zero_()
|
self.conv_states[layer_idx].zero_()
|
||||||
self.conv_states[layer_idx] += conv_state
|
self.conv_states[layer_idx] += conv_state
|
||||||
return self.conv_states[layer_idx]
|
return self.conv_states[layer_idx]
|
||||||
|
|||||||
@@ -421,6 +421,30 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
def test_beam_sample_generate(self):
|
def test_beam_sample_generate(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_dtype_mismatch_handled_in_cache(self):
|
||||||
|
config, input_ids, *args = self.model_tester.prepare_config_and_inputs()
|
||||||
|
model = MambaModel(config)
|
||||||
|
model.to(torch_device).to(torch.float16)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Create cache with float32 dtype
|
||||||
|
cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
|
||||||
|
|
||||||
|
# If code is correct, no error occurs and test passes
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
cache_params=cache_params,
|
||||||
|
use_cache=True,
|
||||||
|
cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsNotNone(outputs)
|
||||||
|
self.assertIsNotNone(outputs.last_hidden_state)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs.last_hidden_state.shape,
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MambaIntegrationTests(unittest.TestCase):
|
class MambaIntegrationTests(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user