From e135a6c93125db85b291380186b840a6b488a151 Mon Sep 17 00:00:00 2001 From: Francisco Kurucz Date: Mon, 5 Dec 2022 10:00:05 -0300 Subject: [PATCH] Fix flax GPT-J-6B linking model in tests (#20556) --- tests/models/gptj/test_modeling_flax_gptj.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/gptj/test_modeling_flax_gptj.py b/tests/models/gptj/test_modeling_flax_gptj.py index 28dd654837..9a6472bc92 100644 --- a/tests/models/gptj/test_modeling_flax_gptj.py +++ b/tests/models/gptj/test_modeling_flax_gptj.py @@ -202,7 +202,7 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left") inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True) - model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gptj-6B") + model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") model.do_sample = False model.config.pad_token_id = model.config.eos_token_id @@ -323,6 +323,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes @tooslow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: - model = model_class_name.from_pretrained("EleutherAI/gptj-6B") + model = model_class_name.from_pretrained("EleutherAI/gpt-j-6B") outputs = model(np.ones((1, 1))) self.assertIsNotNone(outputs)