Fix ESM checkpoints for tests (#20436)

* Re-enable TF ESM tests, make sure we use facebook checkpoints

* make fixup
This commit is contained in:
Matt
2022-11-28 13:19:28 +00:00
committed by GitHub
parent f244a97801
commit 72b19ca680
3 changed files with 17 additions and 11 deletions

View File

@@ -247,7 +247,7 @@ class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase):
class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_protein_folding(self):
model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").float()
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
position_outputs = model(input_ids)["positions"]