Fix slow tests for important models to be compatible with A10 runners (#29905)

* fix mistral and mixtral

* add pdb

* fix mixtral tesst

* fix

* fix mistral ?

* add fix gemma

* fix mistral

* fix

* test

* anoter test

* fix

* fix

* fix mistral tests

* fix them again

* final fixes for mistral

* fix padding right

* fix whipser fa2

* fix

* fix

* fix gemma

* test

* fix llama

* fix

* fix

* fix llama gemma

* add class attribute

* fix CI

* clarify whisper

* compute_capability

* rename names in some comments

* Add   # fmt: skip

* make style

* Update tests/models/mistral/test_modeling_mistral.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* update

---------

Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-04-09 13:28:54 +02:00
committed by GitHub
parent e9c23fa056
commit 08a194fcd6
6 changed files with 246 additions and 110 deletions

View File

@@ -3245,6 +3245,7 @@ class ModelTesterMixin:
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
@@ -3338,6 +3339,7 @@ class ModelTesterMixin:
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
@@ -3427,6 +3429,7 @@ class ModelTesterMixin:
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky
def test_flash_attn_2_generate_left_padding(self):
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
@@ -3470,6 +3473,7 @@ class ModelTesterMixin:
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@is_flaky
@slow
def test_flash_attn_2_generate_padding_right(self):
for model_class in self.all_generative_model_classes:
@@ -3888,19 +3892,20 @@ class ModelTesterMixin:
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_input = inputs_dict[model.main_input_name]
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
batch_size = dummy_attention_mask.shape[0]
if model.config.is_encoder_decoder:
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size
# To avoid errors with padding_side=="right"
if is_padding_right:
dummy_attention_mask = torch.ones_like(dummy_input)
model = model_class.from_pretrained(
tmpdirname,
@@ -3916,6 +3921,9 @@ class ModelTesterMixin:
param.data = param.data.to(torch.float32)
if model.config.is_encoder_decoder:
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
_ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
# with attention mask
_ = model(