Fix flax GPT-J-6B linking model in tests (#20556)
This commit is contained in:
@@ -202,7 +202,7 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
||||
|
||||
model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gptj-6B")
|
||||
model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
|
||||
model.do_sample = False
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
@@ -323,6 +323,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
||||
@tooslow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("EleutherAI/gptj-6B")
|
||||
model = model_class_name.from_pretrained("EleutherAI/gpt-j-6B")
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
Reference in New Issue
Block a user