Fix mistral generate for long prompt / response (#27548)

* Fix mistral generate for long prompt / response

* Add unit test

* fix linter

* fix linter

* fix test

* add assisted generation test for mistral and load the model in 4 bit + fa2
This commit is contained in:
Yanan Xie
2023-11-27 01:18:41 -08:00
committed by GitHub
parent 27b752bcf1
commit b09912c8f4
2 changed files with 31 additions and 1 deletions

View File

@@ -24,6 +24,7 @@ import pytest
from transformers import AutoTokenizer, MistralConfig, is_torch_available
from transformers.testing_utils import (
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
@@ -494,3 +495,32 @@ class MistralIntegrationTest(unittest.TestCase):
del model
backend_empty_cache(torch_device)
gc.collect()
@require_bitsandbytes
@slow
@require_flash_attn
def test_model_7b_long_prompt(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window
input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
device_map="auto",
load_in_4bit=True,
use_flash_attention_2=True,
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
# Assisted generation
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
del model
backend_empty_cache(torch_device)
gc.collect()