adjust input and output texts for test_modeling_recurrent_gemma.py (#39190)

* adjust input and output texts for test_modeling_recurrent_gemma.py

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix bug

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* adjust

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update Expectation match

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* fix

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
kaixuanliu
2025-07-07 21:13:25 +08:00
committed by GitHub
parent 14cba7ad33
commit c4e39ee59c

View File

@@ -20,6 +20,7 @@ from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
Expectations,
require_bitsandbytes,
require_read_token,
require_torch,
@@ -215,12 +216,22 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
@require_read_token
def test_2b_sample(self):
set_seed(0)
EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write <em>P</em> if the underlined word group is a phrase and <em>NP</em> if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip
expectations = Expectations(
{
(None, None): [
"What is Deep learning ?\n\nDeep learning is the next frontier in computer vision. It is an Artificial Intelligence (AI) discipline that is rapidly being adopted across industries. The success of Deep"
],
("cuda", 8): [
"What is Deep learning ?\n\nDeep learning is the next frontier in computer vision, its an incredibly powerful branch of artificial intelligence.\n\nWhat is Dalle?\n\nDalle is",
],
}
)
EXPECTED_TEXT = expectations.get_expectation()
model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
inputs = tokenizer("Where is Paris ?", return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=128, do_sample=True)
inputs = tokenizer("What is Deep learning ?", return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=32, do_sample=True)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXT)
@@ -228,7 +239,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
@require_bitsandbytes
@require_read_token
def test_model_2b_8bit(self):
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of social media on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a 3D"] # fmt: skip
model = AutoModelForCausalLM.from_pretrained(
"gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16