[tests] fix EsmModelIntegrationTest::test_inference_bitsandbytes (#36225)
fix failed test
This commit is contained in:
@@ -335,13 +335,13 @@ class EsmModelIntegrationTest(TestCasePlus):
|
|||||||
def test_inference_bitsandbytes(self):
|
def test_inference_bitsandbytes(self):
|
||||||
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
|
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
|
||||||
|
|
||||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
|
||||||
# Just test if inference works
|
# Just test if inference works
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = model(input_ids)[0]
|
_ = model(input_ids)[0]
|
||||||
|
|
||||||
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True)
|
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True)
|
||||||
|
|
||||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
|
||||||
# Just test if inference works
|
# Just test if inference works
|
||||||
_ = model(input_ids)[0]
|
_ = model(input_ids)[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user