Add ESMFold (#19977)
* initial commit * First draft that gets outputs without crashing! * Add all the ported openfold dependencies * testing * Restructure config files for ESMFold * Debugging to find output discrepancies * Mainly style * Make model runnable without extra deps * Remove utils and merge them to the modeling file * Use correct gelu and remove some debug prints * More cleanup * Update esm docs * Update conversion script to support ESMFold properly * Port some top-level changes from ESMFold repo * Expand EsmFold docstrings * Make attention_mask optional (default to all 1s) * Add inference test for ESMFold * Use config and not n kwargs * Add modeling output class * Remove einops * Remove chunking in ESM FFN * Update tests for ESMFold * Quality * REpo consistency * Remove tree dependency from ESMFold * make fixup * Add an error in case my structure map function breaks later * Remove needless code * Stop auto-casting the LM to float16 so CPU tests pass * Stop auto-casting the LM to float16 so CPU tests pass * Final test updates * Split test file * Copyright and quality * Unpin PyTorch to see built doc * Fix config file to_dict() method * Add some docstrings to the output * Skip TF checkpoint tests for ESM until we reupload those * make fixup * More docstrings * Unpin to get even with main * Flag example to write Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
@@ -254,7 +254,7 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_tf
|
||||
class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
|
||||
def test_inference_masked_lm(self):
|
||||
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
|
||||
@@ -268,7 +268,7 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
|
||||
|
||||
@slow
|
||||
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
|
||||
def test_inference_no_head(self):
|
||||
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user