Generate: fix speculative decoding (#28166)

Co-authored-by: Merve Noyan <merveenoyan@gmail.com>
This commit is contained in:
Joao Gante
2023-12-20 18:55:35 +00:00
committed by GitHub
parent 01c081d138
commit 45b70384a7
5 changed files with 90 additions and 72 deletions

View File

@@ -21,7 +21,7 @@ import unittest
import pytest
from transformers import AutoTokenizer, MistralConfig, is_torch_available
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
from transformers.testing_utils import (
backend_empty_cache,
require_bitsandbytes,
@@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase):
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
def test_speculative_generation(self):
EXPECTED_TEXT_COMPLETION = (
"My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
)
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
set_seed(0)
generated_ids = model.generate(
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
del model
backend_empty_cache(torch_device)
gc.collect()