Add ESMFold (#19977)
* initial commit * First draft that gets outputs without crashing! * Add all the ported openfold dependencies * testing * Restructure config files for ESMFold * Debugging to find output discrepancies * Mainly style * Make model runnable without extra deps * Remove utils and merge them to the modeling file * Use correct gelu and remove some debug prints * More cleanup * Update esm docs * Update conversion script to support ESMFold properly * Port some top-level changes from ESMFold repo * Expand EsmFold docstrings * Make attention_mask optional (default to all 1s) * Add inference test for ESMFold * Use config and not n kwargs * Add modeling output class * Remove einops * Remove chunking in ESM FFN * Update tests for ESMFold * Quality * Repo consistency * Remove tree dependency from ESMFold * make fixup * Add an error in case my structure map function breaks later * Remove needless code * Stop auto-casting the LM to float16 so CPU tests pass * Stop auto-casting the LM to float16 so CPU tests pass * Final test updates * Split test file * Copyright and quality * Unpin PyTorch to see built doc * Fix config file to_dict() method * Add some docstrings to the output * Skip TF checkpoint tests for ESM until we reupload those * make fixup * More docstrings * Unpin to get even with main * Flag example to write Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
@@ -20,7 +20,6 @@ import unittest
|
||||
from transformers import EsmConfig, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
@@ -49,7 +48,7 @@ class EsmModelTester:
|
||||
self.use_input_mask = True
|
||||
self.use_token_type_ids = False
|
||||
self.use_labels = True
|
||||
self.vocab_size = 99
|
||||
self.vocab_size = 33
|
||||
self.hidden_size = 32
|
||||
self.num_hidden_layers = 5
|
||||
self.num_attention_heads = 4
|
||||
@@ -145,7 +144,7 @@ class EsmModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
class EsmModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
test_mismatched_shapes = False
|
||||
|
||||
@@ -253,28 +252,32 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
class EsmModelIntegrationTest(TestCasePlus):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
with torch.no_grad():
|
||||
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model.eval()
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
vocab_size = 33
|
||||
vocab_size = 33
|
||||
|
||||
expected_shape = torch.Size((1, 6, vocab_size))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_shape = torch.Size((1, 6, vocab_size))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
expected_slice = torch.tensor(
|
||||
[[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
with torch.no_grad():
|
||||
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
model.eval()
|
||||
|
||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||
output = model(input_ids)[0]
|
||||
# compare the actual values for a slice.
|
||||
expected_slice = torch.tensor(
|
||||
[[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
|
||||
output = model(input_ids)[0]
|
||||
# compare the actual values for a slice.
|
||||
expected_slice = torch.tensor(
|
||||
[[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user