Fix ESM checkpoints for tests (#20436)
* Re-enable TF ESM tests, make sure we use facebook checkpoints
* make fixup
This commit is contained in:
@@ -274,7 +274,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
|||||||
@slow
|
@slow
|
||||||
def test_inference_masked_lm(self):
|
def test_inference_masked_lm(self):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||||
model.eval()
|
model.eval()
|
||||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||||
output = model(input_ids)[0]
|
output = model(input_ids)[0]
|
||||||
@@ -292,7 +292,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
|||||||
@slow
|
@slow
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
class EsmModelIntegrationTest(TestCasePlus):
|
class EsmModelIntegrationTest(TestCasePlus):
|
||||||
@slow
|
@slow
|
||||||
def test_inference_protein_folding(self):
|
def test_inference_protein_folding(self):
|
||||||
model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
|
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").float()
|
||||||
model.eval()
|
model.eval()
|
||||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||||
position_outputs = model(input_ids)["positions"]
|
position_outputs = model(input_ids)["positions"]
|
||||||
|
|||||||
@@ -254,9 +254,9 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFEsmModelIntegrationTest(unittest.TestCase):
|
class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||||
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
|
@slow
|
||||||
def test_inference_masked_lm(self):
|
def test_inference_masked_lm(self):
|
||||||
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||||
|
|
||||||
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
|
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
|
||||||
output = model(input_ids)[0]
|
output = model(input_ids)[0]
|
||||||
@@ -264,13 +264,19 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(list(output.numpy().shape), expected_shape)
|
self.assertEqual(list(output.numpy().shape), expected_shape)
|
||||||
# compare the actual values for a slice.
|
# compare the actual values for a slice.
|
||||||
expected_slice = tf.constant(
|
expected_slice = tf.constant(
|
||||||
[[[15.0963, -6.6414, -1.1346], [-0.2209, -9.9633, 4.2082], [-1.6045, -10.0011, 1.5882]]]
|
[
|
||||||
|
[
|
||||||
|
[8.920963, -10.591399, -6.467397],
|
||||||
|
[-6.3980846, -13.913257, -1.1291938],
|
||||||
|
[-7.7815733, -13.951929, -3.7438734],
|
||||||
|
]
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
|
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
|
||||||
|
|
||||||
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
|
@slow
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
model = TFEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
||||||
|
|
||||||
input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||||
output = model(input_ids)[0]
|
output = model(input_ids)[0]
|
||||||
@@ -278,9 +284,9 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_slice = tf.constant(
|
expected_slice = tf.constant(
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
[0.144337, 0.541198, 0.32479298],
|
[0.14422388, 0.5411936, 0.3249576],
|
||||||
[0.30328932, 0.00519154, 0.31089523],
|
[0.30342406, 0.00549317, 0.31096306],
|
||||||
[0.32273883, -0.24992886, 0.34143737],
|
[0.32278833, -0.24974644, 0.34135976],
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user