From b09912c8f452ac485933ac0f86937aa01de3c398 Mon Sep 17 00:00:00 2001 From: Yanan Xie <108375850+lorabit110@users.noreply.github.com> Date: Mon, 27 Nov 2023 01:18:41 -0800 Subject: [PATCH] Fix mistral generate for long prompt / response (#27548) * Fix mistral generate for long prompt / response * Add unit test * fix linter * fix linter * fix test * add assisted generation test for mistral and load the model in 4 bit + fa2 --- .../models/mistral/modeling_mistral.py | 2 +- tests/models/mistral/test_modeling_mistral.py | 30 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 72525e665a..9c6300ab5e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -364,7 +364,7 @@ class MistralFlashAttention2(MistralAttention): if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window: - slicing_tokens = kv_seq_len - self.config.sliding_window + slicing_tokens = 1 - self.config.sliding_window past_key = past_key_value[0] past_value = past_key_value[1] diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index cedcdeb4b9..dba013b205 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -24,6 +24,7 @@ import pytest from transformers import AutoTokenizer, MistralConfig, is_torch_available from transformers.testing_utils import ( backend_empty_cache, + require_bitsandbytes, require_flash_attn, require_torch, require_torch_gpu, @@ -494,3 +495,32 @@ class MistralIntegrationTest(unittest.TestCase): del model backend_empty_cache(torch_device) gc.collect() + + @require_bitsandbytes + @slow + @require_flash_attn + def test_model_7b_long_prompt(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", + device_map="auto", + load_in_4bit=True, + use_flash_attention_2=True, + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + del model + backend_empty_cache(torch_device) + gc.collect()