[Mixtral & Mistral] Add support for sdpa (#28133)

* some nits

* update test

* add support d\sd[a

* remove some dummy inputs

* all good

* style

* nits

* fixes

* fix more copies

* nits

* styling

* fix

* Update src/transformers/models/mistral/modeling_mistral.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add a slow test just to be sure

* fixup

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Arthur
2023-12-21 12:38:22 +01:00
committed by GitHub
parent 814619f54f
commit f9a98c476c
4 changed files with 253 additions and 19 deletions

View File

@@ -28,6 +28,7 @@ from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
@@ -528,6 +529,44 @@ class MistralIntegrationTest(unittest.TestCase):
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_torch_sdpa
def test_model_7b_long_prompt_sdpa(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window
input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
device_map="auto",
attn_implementation="sdpa",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
# Assisted generation
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
backend_empty_cache(torch_device)
gc.collect()
EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. Im not a big"""
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_speculative_generation(self):
EXPECTED_TEXT_COMPLETION = (

View File

@@ -38,11 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralModel,
)
from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
class MixtralModelTester: