Generate: Add new decoding strategy "DoLa" in .generate() (#29619)

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Yung-Sung Chuang
2024-07-09 09:37:38 -07:00
committed by GitHub
parent 99c0e55335
commit d094d8d9ec
7 changed files with 530 additions and 5 deletions

View File

@@ -839,7 +839,6 @@ class GemmaIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
@slow
@@ -898,3 +897,24 @@ class GemmaIntegrationTest(unittest.TestCase):
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
def test_model_2b_bf16_dola(self):
model_id = "google/gemma-2b"
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
EXPECTED_TEXTS = [
"Hello I am doing an experiment and need to get the mass of a block. The problem is, it has no scale",
"Hi today we have the review for a <strong>2016/2017</strong> season of",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(
**inputs, max_new_tokens=20, do_sample=False, dola_layers="low", repetition_penalty=1.2
)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)

View File

@@ -703,6 +703,29 @@ class LlamaIntegrationTest(unittest.TestCase):
)
)
@slow
def test_model_7b_dola_generation(self):
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
EXPECTED_TEXT_COMPLETION = (
"Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of "
"physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of "
"relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our "
"understanding of space and time."
)
prompt = "Simply put, the theory of relativity states that "
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
)
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# greedy generation outputs
generated_ids = model.generate(
**model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
@require_torch_gpu
@require_read_token

View File

@@ -555,6 +555,30 @@ class MistralIntegrationTest(unittest.TestCase):
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
@slow
def test_model_7b_dola_generation(self):
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
EXPECTED_TEXT_COMPLETION = (
"""My favourite condiment is 100% ketchup. I love it on everything, and Im not ash"""
)
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
generated_ids = model.generate(
input_ids, max_new_tokens=20, temperature=0, dola_layers="low", repetition_penalty=1.2
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
del model
backend_empty_cache(torch_device)
gc.collect()
@require_bitsandbytes
@slow
@require_flash_attn