Generate: fix speculative decoding (#28166)
Co-authored-by: Merve Noyan <merveenoyan@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer, MistralConfig, is_torch_available
|
||||
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_bitsandbytes,
|
||||
@@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
def test_speculative_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
|
||||
)
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||
|
||||
# greedy generation outputs
|
||||
set_seed(0)
|
||||
generated_ids = model.generate(
|
||||
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user