Generate: Add new decoding strategy "DoLa" in .generate() (#29619)
Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -703,6 +703,29 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of "
|
||||
"physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of "
|
||||
"relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our "
|
||||
"understanding of space and time."
|
||||
)
|
||||
prompt = "Simply put, the theory of relativity states that "
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
|
||||
)
|
||||
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(
|
||||
**model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_read_token
|
||||
|
||||
Reference in New Issue
Block a user