Fix repo names for ESM tests (#19451)

This commit is contained in:
Matt
2022-10-10 13:20:00 +01:00
committed by GitHub
parent cbb8a37929
commit 4107445a0f

View File

@@ -245,7 +245,7 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class EsmModelIntegrationTest(TestCasePlus): class EsmModelIntegrationTest(TestCasePlus):
@slow @slow
def test_inference_masked_lm(self): def test_inference_masked_lm(self):
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm-2-8m") model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0] output = model(input_ids)[0]
@@ -261,7 +261,7 @@ class EsmModelIntegrationTest(TestCasePlus):
@slow @slow
def test_inference_no_head(self): def test_inference_no_head(self):
model = EsmModel.from_pretrained("Rocketknight1/esm-2-8m") model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]) input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
output = model(input_ids)[0] output = model(input_ids)[0]
@@ -276,7 +276,7 @@ class EsmModelIntegrationTest(TestCasePlus):
keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"] keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
config = EsmConfig.from_pretrained("Rocketknight1/esm-2-8m") config = EsmConfig.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
config_tied = deepcopy(config) config_tied = deepcopy(config)
config_tied.tie_word_embeddings = True config_tied.tie_word_embeddings = True
config_untied = deepcopy(config) config_untied = deepcopy(config)