From fae0f3dde83b7a54441f7a5bb0fc45d354fe81ce Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Mon, 17 Feb 2025 18:10:33 +0800 Subject: [PATCH] [tests] fix `EsmModelIntegrationTest::test_inference_bitsandbytes` (#36225) fix failed test --- tests/models/esm/test_modeling_esm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index 7504ec2462..7be71c22c7 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -335,13 +335,13 @@ class EsmModelIntegrationTest(TestCasePlus): def test_inference_bitsandbytes(self): model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True) - input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) + input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device) # Just test if inference works with torch.no_grad(): _ = model(input_ids)[0] model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True) - input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) + input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device) # Just test if inference works _ = model(input_ids)[0]