Add ESMFold (#19977)

* initial commit

* First draft that gets outputs without crashing!

* Add all the ported openfold dependencies

* testing

* Restructure config files for ESMFold

* Debugging to find output discrepancies

* Mainly style

* Make model runnable without extra deps

* Remove utils and merge them to the modeling file

* Use correct gelu and remove some debug prints

* More cleanup

* Update esm docs

* Update conversion script to support ESMFold properly

* Port some top-level changes from ESMFold repo

* Expand EsmFold docstrings

* Make attention_mask optional (default to all 1s)

* Add inference test for ESMFold

* Use config and not n kwargs

* Add modeling output class

* Remove einops

* Remove chunking in ESM FFN

* Update tests for ESMFold

* Quality

* REpo consistency

* Remove tree dependency from ESMFold

* make fixup

* Add an error in case my structure map function breaks later

* Remove needless code

* Stop auto-casting the LM to float16 so CPU tests pass

* Stop auto-casting the LM to float16 so CPU tests pass

* Final test updates

* Split test file

* Copyright and quality

* Unpin PyTorch to see built doc

* Fix config file to_dict() method

* Add some docstrings to the output

* Skip TF checkpoint tests for ESM until we reupload those

* make fixup

* More docstrings

* Unpin to get even with main

* Flag example to write

Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
Matt
2022-11-01 01:32:58 +00:00
committed by GitHub
parent 4c9e0f029e
commit 7f9b7b3f0e
22 changed files with 6820 additions and 89 deletions

View File

@@ -254,7 +254,7 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
@require_tf
class TFEsmModelIntegrationTest(unittest.TestCase):
@slow
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
def test_inference_masked_lm(self):
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
@@ -268,7 +268,7 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
@slow
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
def test_inference_no_head(self):
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")