diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index 62a0aef6f4..6235097e26 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -20,6 +20,7 @@ from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed from transformers.testing_utils import ( + Expectations, require_bitsandbytes, require_read_token, require_torch, @@ -215,12 +216,22 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase): @require_read_token def test_2b_sample(self): set_seed(0) - EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write P if the underlined word group is a phrase and NP if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip + expectations = Expectations( + { + (None, None): [ + "What is Deep learning ?\n\nDeep learning is the next frontier in computer vision. It is an Artificial Intelligence (AI) discipline that is rapidly being adopted across industries. The success of Deep" + ], + ("cuda", 8): [ + "What is Deep learning ?\n\nDeep learning is the next frontier in computer vision, it’s an incredibly powerful branch of artificial intelligence.\n\nWhat is Dalle?\n\nDalle is", + ], + } + ) + EXPECTED_TEXT = expectations.get_expectation() model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device) tokenizer = AutoTokenizer.from_pretrained(self.model_id) - inputs = tokenizer("Where is Paris ?", return_tensors="pt", padding=True).to(torch_device) - output = model.generate(**inputs, max_new_tokens=128, do_sample=True) + inputs = tokenizer("What is Deep learning ?", return_tensors="pt", padding=True).to(torch_device) + output = model.generate(**inputs, max_new_tokens=32, do_sample=True) output_text = tokenizer.batch_decode(output, skip_special_tokens=True) self.assertEqual(output_text, EXPECTED_TEXT) @@ -228,7 +239,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase): @require_bitsandbytes @require_read_token def test_model_2b_8bit(self): - EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip + EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of social media on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a 3D"] # fmt: skip model = AutoModelForCausalLM.from_pretrained( "gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16