From 0fa46524ac2f6e564c1ca14a60761a09b4fbdfd5 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 24 Mar 2023 11:33:16 +0000 Subject: [PATCH] Generate: Add GPTNeoX integration test (#22346) --- .../models/gpt_neox/test_modeling_gpt_neox.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 89765e0561..1798f01358 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -17,8 +17,8 @@ import unittest -from transformers import GPTNeoXConfig, is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -232,3 +232,28 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @unittest.skip(reason="Feed forward chunking is not implemented") def test_feed_forward_chunking(self): pass + + +@require_torch +class GPTNeoXLanguageGenerationTest(unittest.TestCase): + @slow + def test_lm_generate_codegen(self): + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped") + for checkpointing in [True, False]: + model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped") + + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() + model.to(torch_device) + + inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device) + expected_output = ( + "My favorite food is the chicken and rice.\n\nI love to cook and bake. I love to cook and bake" + ) + + output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20) + output_str = tokenizer.batch_decode(output_ids)[0] + + self.assertEqual(output_str, expected_output)