TF port of ESM (#19587)
* Partial TF port for ESM model * Add ESM-TF tests * Add the various imports for TF-ESM * TF weight conversion almost ready * Stop ignoring the decoder weights in PT * Add tests and lots of fixes * fix-copies * Fix imports, add model docs * Add get_vocab() to tokenizer * Fix vocab links for pretrained files * Allow multiple inputs with a sep * Use EOS as SEP token because ESM vocab lacks SEP * Correctly return special tokens mask from ESM tokenizer * make fixup * Stop testing unsupported embedding resizing * Handle TF bias correctly * Skip all models with slow tokenizers in the token classification test * Fixing the batch/unbatcher of pipelines to accomodate the `None` being passed around. * Fixing pipeline bug caused by slow tokenizer being different. * Update src/transformers/models/esm/modeling_tf_esm.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/esm/modeling_tf_esm.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/esm/modeling_tf_esm.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update set_input_embeddings and the copyright notices Co-authored-by: Your Name <you@example.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -240,6 +240,14 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(position_ids.shape, expected_positions.shape)
|
||||
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
|
||||
|
||||
@unittest.skip("Esm does not support embedding resizing")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Esm does not support embedding resizing")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class EsmModelIntegrationTest(TestCasePlus):
|
||||
@@ -270,24 +278,3 @@ class EsmModelIntegrationTest(TestCasePlus):
|
||||
[[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
def test_lm_head_ignore_keys(self):
|
||||
from copy import deepcopy
|
||||
|
||||
keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
|
||||
keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
|
||||
config = EsmConfig.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
|
||||
config_tied = deepcopy(config)
|
||||
config_tied.tie_word_embeddings = True
|
||||
config_untied = deepcopy(config)
|
||||
config_untied.tie_word_embeddings = False
|
||||
for cls in [EsmForMaskedLM]:
|
||||
model = cls(config_tied)
|
||||
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)
|
||||
|
||||
# the keys should be different when embeddings aren't tied
|
||||
model = cls(config_untied)
|
||||
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)
|
||||
|
||||
# test that saving works with updated ignore keys - just testing that it doesn't fail
|
||||
model.save_pretrained(self.get_auto_remove_tmp_dir())
|
||||
|
||||
Reference in New Issue
Block a user