adjust input and output texts for test_modeling_recurrent_gemma.py (#39190)
* adjust input and output texts for test_modeling_recurrent_gemma.py Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * fix bug Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * adjust Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * update Expectation match Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * fix --------- Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,7 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_bitsandbytes,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
@@ -215,12 +216,22 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
|
||||
@require_read_token
|
||||
def test_2b_sample(self):
|
||||
set_seed(0)
|
||||
EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write <em>P</em> if the underlined word group is a phrase and <em>NP</em> if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): [
|
||||
"What is Deep learning ?\n\nDeep learning is the next frontier in computer vision. It is an Artificial Intelligence (AI) discipline that is rapidly being adopted across industries. The success of Deep"
|
||||
],
|
||||
("cuda", 8): [
|
||||
"What is Deep learning ?\n\nDeep learning is the next frontier in computer vision, it’s an incredibly powerful branch of artificial intelligence.\n\nWhat is Dalle?\n\nDalle is",
|
||||
],
|
||||
}
|
||||
)
|
||||
EXPECTED_TEXT = expectations.get_expectation()
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
inputs = tokenizer("Where is Paris ?", return_tensors="pt", padding=True).to(torch_device)
|
||||
output = model.generate(**inputs, max_new_tokens=128, do_sample=True)
|
||||
inputs = tokenizer("What is Deep learning ?", return_tensors="pt", padding=True).to(torch_device)
|
||||
output = model.generate(**inputs, max_new_tokens=32, do_sample=True)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXT)
|
||||
@@ -228,7 +239,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase):
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_2b_8bit(self):
|
||||
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip
|
||||
EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of social media on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a 3D"] # fmt: skip
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16
|
||||
|
||||
Reference in New Issue
Block a user