From a0803a9555ce77adfdcb22da48e5c07f5d6afbd7 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Thu, 3 Apr 2025 16:38:03 +0800 Subject: [PATCH] [tests] fix mamba integration simple inference precision issue (#37193) * fix precision issue * use float32 --- tests/models/mamba/test_modeling_mamba.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 63540575b2..bd69446e3b 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -451,7 +451,7 @@ class MambaIntegrationTests(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") tokenizer.pad_token = tokenizer.eos_token - model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16) + model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float32) model.to(device) input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device) @@ -464,14 +464,13 @@ class MambaIntegrationTests(unittest.TestCase): EXPECTED_LOGITS_NO_GRAD = torch.tensor( [ - -55.6875, -69.8750, -49.9062, -51.7500, -57.6875, -57.9375, -56.9688, - -57.9375, -54.6875, -55.9375, -55.3125, -58.0938, -60.5625, -47.0000, - -52.0312, -49.7812, -55.9375, -57.9062, -56.7812, -57.1250, -57.3438, - -58.3125, -57.8125, -58.7812, -59.6250, -59.0938, -58.7188, -52.9375, - -53.4688, -57.3750, -56.9375, -55.7500, -53.3125, -55.8438, -57.0000, - -56.9062, -56.2188, -54.7188, -56.4375, -57.5000 - ] - ,dtype=torch.float32) # fmt: skip + -55.6909, -69.7903, -49.8981, -51.7581, -57.6544, -57.9368, -56.9591, + -57.9033, -54.6787, -55.9261, -55.3011, -58.0765, -60.5642, -47.0176, + -52.0344, -49.7836, -55.9463, -57.8957, -56.7627, -57.1080, -57.3434, + -58.3015, -57.7875, -58.7760, -59.6037, -59.0665, -58.7087, -52.9293, + -53.4654, -57.3466, -56.9294, -55.7314, -53.3141, -55.8171, -56.9879, + -56.9121, -56.2139, -54.7198, -56.4134, -57.4825 + ]) # fmt: skip torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)