From 72b19ca680b5b9fb4cef6ed8c599c48d2449cb8b Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 28 Nov 2022 13:19:28 +0000 Subject: [PATCH] Fix ESM checkpoints for tests (#20436) * Re-enable TF ESM tests, make sure we use facebook checkpoints * make fixup --- tests/models/esm/test_modeling_esm.py | 4 ++-- tests/models/esm/test_modeling_esmfold.py | 2 +- tests/models/esm/test_modeling_tf_esm.py | 22 ++++++++++++++-------- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index f6c0fcafb4..8db290880e 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -274,7 +274,7 @@ class EsmModelIntegrationTest(TestCasePlus): @slow def test_inference_masked_lm(self): with torch.no_grad(): - model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D") model.eval() input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) output = model(input_ids)[0] @@ -292,7 +292,7 @@ class EsmModelIntegrationTest(TestCasePlus): @slow def test_inference_no_head(self): with torch.no_grad(): - model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D") model.eval() input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) diff --git a/tests/models/esm/test_modeling_esmfold.py b/tests/models/esm/test_modeling_esmfold.py index c6dd7c5655..ed307beef1 100644 --- a/tests/models/esm/test_modeling_esmfold.py +++ b/tests/models/esm/test_modeling_esmfold.py @@ -247,7 +247,7 @@ class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase): class EsmModelIntegrationTest(TestCasePlus): @slow def test_inference_protein_folding(self): - model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float() + model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").float() model.eval() input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) position_outputs = model(input_ids)["positions"] diff --git a/tests/models/esm/test_modeling_tf_esm.py b/tests/models/esm/test_modeling_tf_esm.py index 513dbb1a7b..732989387b 100644 --- a/tests/models/esm/test_modeling_tf_esm.py +++ b/tests/models/esm/test_modeling_tf_esm.py @@ -254,9 +254,9 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase): @require_tf class TFEsmModelIntegrationTest(unittest.TestCase): - @unittest.skip("Temporarily disabled as we update ESM model checkpoints") + @slow def test_inference_masked_lm(self): - model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D") input_ids = tf.constant([[0, 1, 2, 3, 4, 5]]) output = model(input_ids)[0] @@ -264,13 +264,19 @@ class TFEsmModelIntegrationTest(unittest.TestCase): self.assertEqual(list(output.numpy().shape), expected_shape) # compare the actual values for a slice. expected_slice = tf.constant( - [[[15.0963, -6.6414, -1.1346], [-0.2209, -9.9633, 4.2082], [-1.6045, -10.0011, 1.5882]]] + [ + [ + [8.920963, -10.591399, -6.467397], + [-6.3980846, -13.913257, -1.1291938], + [-7.7815733, -13.951929, -3.7438734], + ] + ] ) self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4)) - @unittest.skip("Temporarily disabled as we update ESM model checkpoints") + @slow def test_inference_no_head(self): - model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") + model = TFEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D") input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) output = model(input_ids)[0] @@ -278,9 +284,9 @@ class TFEsmModelIntegrationTest(unittest.TestCase): expected_slice = tf.constant( [ [ - [0.144337, 0.541198, 0.32479298], - [0.30328932, 0.00519154, 0.31089523], - [0.32273883, -0.24992886, 0.34143737], + [0.14422388, 0.5411936, 0.3249576], + [0.30342406, 0.00549317, 0.31096306], + [0.32278833, -0.24974644, 0.34135976], ] ] )