From c4e39ee59c7ccc552e67889c1b81a574d5badf2e Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Mon, 7 Jul 2025 21:13:25 +0800 Subject: [PATCH] adjust input and output texts for test_modeling_recurrent_gemma.py (#39190) * adjust input and output texts for test_modeling_recurrent_gemma.py Signed-off-by: Liu, Kaixuan * fix bug Signed-off-by: Liu, Kaixuan * adjust Signed-off-by: Liu, Kaixuan * update Expectation match Signed-off-by: Liu, Kaixuan * fix --------- Signed-off-by: Liu, Kaixuan Co-authored-by: ydshieh --- .../test_modeling_recurrent_gemma.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index 62a0aef6f4..6235097e26 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -20,6 +20,7 @@ from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed from transformers.testing_utils import ( + Expectations, require_bitsandbytes, require_read_token, require_torch, @@ -215,12 +216,22 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase): @require_read_token def test_2b_sample(self): set_seed(0) - EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write P if the underlined word group is a phrase and NP if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip + expectations = Expectations( + { + (None, None): [ + "What is Deep learning ?\n\nDeep learning is the next frontier in computer vision. It is an Artificial Intelligence (AI) discipline that is rapidly being adopted across industries. The success of Deep" + ], + ("cuda", 8): [ + "What is Deep learning ?\n\nDeep learning is the next frontier in computer vision, it’s an incredibly powerful branch of artificial intelligence.\n\nWhat is Dalle?\n\nDalle is", + ], + } + ) + EXPECTED_TEXT = expectations.get_expectation() model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device) tokenizer = AutoTokenizer.from_pretrained(self.model_id) - inputs = tokenizer("Where is Paris ?", return_tensors="pt", padding=True).to(torch_device) - output = model.generate(**inputs, max_new_tokens=128, do_sample=True) + inputs = tokenizer("What is Deep learning ?", return_tensors="pt", padding=True).to(torch_device) + output = model.generate(**inputs, max_new_tokens=32, do_sample=True) output_text = tokenizer.batch_decode(output, skip_special_tokens=True) self.assertEqual(output_text, EXPECTED_TEXT) @@ -228,7 +239,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase): @require_bitsandbytes @require_read_token def test_model_2b_8bit(self): - EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip + EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of social media on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a 3D"] # fmt: skip model = AutoModelForCausalLM.from_pretrained( "gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16