From 4107445a0ffbb5a08587307af3980117341311c1 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 10 Oct 2022 13:20:00 +0100 Subject: [PATCH] Fix repo names for ESM tests (#19451) --- tests/models/esm/test_modeling_esm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index 7bd0a36c8b..dce9cb69e8 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -245,7 +245,7 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class EsmModelIntegrationTest(TestCasePlus): @slow def test_inference_masked_lm(self): - model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm-2-8m") + model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) output = model(input_ids)[0] @@ -261,7 +261,7 @@ class EsmModelIntegrationTest(TestCasePlus): @slow def test_inference_no_head(self): - model = EsmModel.from_pretrained("Rocketknight1/esm-2-8m") + model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) output = model(input_ids)[0] @@ -276,7 +276,7 @@ class EsmModelIntegrationTest(TestCasePlus): keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"] - config = EsmConfig.from_pretrained("Rocketknight1/esm-2-8m") + config = EsmConfig.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D") config_tied = deepcopy(config) config_tied.tie_word_embeddings = True config_untied = deepcopy(config)