Fix ESM checkpoints for tests (#20436)
* Re-enable TF ESM tests, make sure we use facebook checkpoints * make fixup
This commit is contained in:
@@ -274,7 +274,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
with torch.no_grad():
|
||||
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||
model.eval()
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
@@ -292,7 +292,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
with torch.no_grad():
|
||||
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||
model.eval()
|
||||
|
||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||
|
||||
Reference in New Issue
Block a user