Fix mistral generate for long prompt / response (#27548)
* Fix mistral generate for long prompt / response * Add unit test * fix linter * fix linter * fix test * add assisted generation test for mistral and load the model in 4 bit + fa2
This commit is contained in:
@@ -364,7 +364,7 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||||
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
|
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
|
||||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
slicing_tokens = 1 - self.config.sliding_window
|
||||||
|
|
||||||
past_key = past_key_value[0]
|
past_key = past_key_value[0]
|
||||||
past_value = past_key_value[1]
|
past_value = past_key_value[1]
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import pytest
|
|||||||
from transformers import AutoTokenizer, MistralConfig, is_torch_available
|
from transformers import AutoTokenizer, MistralConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
backend_empty_cache,
|
backend_empty_cache,
|
||||||
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
@@ -494,3 +495,32 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
del model
|
del model
|
||||||
backend_empty_cache(torch_device)
|
backend_empty_cache(torch_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
@require_bitsandbytes
|
||||||
|
@slow
|
||||||
|
@require_flash_attn
|
||||||
|
def test_model_7b_long_prompt(self):
|
||||||
|
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
||||||
|
# An input with 4097 tokens that is above the size of the sliding window
|
||||||
|
input_ids = [1] + [306, 338] * 2048
|
||||||
|
model = MistralForCausalLM.from_pretrained(
|
||||||
|
"mistralai/Mistral-7B-v0.1",
|
||||||
|
device_map="auto",
|
||||||
|
load_in_4bit=True,
|
||||||
|
use_flash_attention_2=True,
|
||||||
|
)
|
||||||
|
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||||
|
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||||
|
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||||
|
|
||||||
|
# Assisted generation
|
||||||
|
assistant_model = model
|
||||||
|
assistant_model.generation_config.num_assistant_tokens = 2
|
||||||
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
|
||||||
|
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||||
|
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||||
|
|
||||||
|
del assistant_model
|
||||||
|
del model
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
gc.collect()
|
||||||
|
|||||||
Reference in New Issue
Block a user