Fix flax GPT-J-6B linking model in tests (#20556)
This commit is contained in:
@@ -202,7 +202,7 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
|||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
|
||||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
|
||||||
|
|
||||||
model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gptj-6B")
|
model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
|
||||||
model.do_sample = False
|
model.do_sample = False
|
||||||
model.config.pad_token_id = model.config.eos_token_id
|
model.config.pad_token_id = model.config.eos_token_id
|
||||||
|
|
||||||
@@ -323,6 +323,6 @@ class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
|
|||||||
@tooslow
|
@tooslow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_class_name in self.all_model_classes:
|
for model_class_name in self.all_model_classes:
|
||||||
model = model_class_name.from_pretrained("EleutherAI/gptj-6B")
|
model = model_class_name.from_pretrained("EleutherAI/gpt-j-6B")
|
||||||
outputs = model(np.ones((1, 1)))
|
outputs = model(np.ones((1, 1)))
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user