Fix ESM checkpoints for tests (#20436)
* Re-enable TF ESM tests, make sure we use facebook checkpoints * make fixup
This commit is contained in:
@@ -274,7 +274,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
with torch.no_grad():
|
||||
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||
model.eval()
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
@@ -292,7 +292,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
with torch.no_grad():
|
||||
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||
model.eval()
|
||||
|
||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||
|
||||
@@ -247,7 +247,7 @@ class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class EsmModelIntegrationTest(TestCasePlus):
|
||||
@slow
|
||||
def test_inference_protein_folding(self):
|
||||
model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
|
||||
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").float()
|
||||
model.eval()
|
||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||
position_outputs = model(input_ids)["positions"]
|
||||
|
||||
@@ -254,9 +254,9 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_tf
|
||||
class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||
|
||||
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
@@ -264,13 +264,19 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(list(output.numpy().shape), expected_shape)
|
||||
# compare the actual values for a slice.
|
||||
expected_slice = tf.constant(
|
||||
[[[15.0963, -6.6414, -1.1346], [-0.2209, -9.9633, 4.2082], [-1.6045, -10.0011, 1.5882]]]
|
||||
[
|
||||
[
|
||||
[8.920963, -10.591399, -6.467397],
|
||||
[-6.3980846, -13.913257, -1.1291938],
|
||||
[-7.7815733, -13.951929, -3.7438734],
|
||||
]
|
||||
]
|
||||
)
|
||||
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
|
||||
|
||||
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model = TFEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||
|
||||
input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||
output = model(input_ids)[0]
|
||||
@@ -278,9 +284,9 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice = tf.constant(
|
||||
[
|
||||
[
|
||||
[0.144337, 0.541198, 0.32479298],
|
||||
[0.30328932, 0.00519154, 0.31089523],
|
||||
[0.32273883, -0.24992886, 0.34143737],
|
||||
[0.14422388, 0.5411936, 0.3249576],
|
||||
[0.30342406, 0.00549317, 0.31096306],
|
||||
[0.32278833, -0.24974644, 0.34135976],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user