Generate: Add GPTNeoX integration test (#22346)
This commit is contained in:
@@ -17,8 +17,8 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import GPTNeoXConfig, is_torch_available
|
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -232,3 +232,28 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
@unittest.skip(reason="Feed forward chunking is not implemented")
|
@unittest.skip(reason="Feed forward chunking is not implemented")
|
||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_codegen(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
|
||||||
|
for checkpointing in [True, False]:
|
||||||
|
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
|
||||||
|
|
||||||
|
if checkpointing:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
else:
|
||||||
|
model.gradient_checkpointing_disable()
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
|
||||||
|
expected_output = (
|
||||||
|
"My favorite food is the chicken and rice.\n\nI love to cook and bake. I love to cook and bake"
|
||||||
|
)
|
||||||
|
|
||||||
|
output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
|
||||||
|
output_str = tokenizer.batch_decode(output_ids)[0]
|
||||||
|
|
||||||
|
self.assertEqual(output_str, expected_output)
|
||||||
|
|||||||
Reference in New Issue
Block a user