[Mixtral & Mistral] Add support for sdpa (#28133)

* some nits

* update test

* add support for sdpa

* remove some dummy inputs

* all good

* style

* nits

* fixes

* fix more copies

* nits

* styling

* fix

* Update src/transformers/models/mistral/modeling_mistral.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add a slow test just to be sure

* fixup

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Arthur
2023-12-21 12:38:22 +01:00
committed by GitHub
parent 814619f54f
commit f9a98c476c
4 changed files with 253 additions and 19 deletions


@@ -38,11 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch
 
-    from transformers import (
-        MixtralForCausalLM,
-        MixtralForSequenceClassification,
-        MixtralModel,
-    )
+    from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
 
 
 class MixtralModelTester: