Fix slow tests for important models to be compatible with A10 runners (#29905)
* fix mistral and mixtral * add pdb * fix mixtral test * fix * fix mistral ? * add fix gemma * fix mistral * fix * test * another test * fix * fix * fix mistral tests * fix them again * final fixes for mistral * fix padding right * fix whisper fa2 * fix * fix * fix gemma * test * fix llama * fix * fix * fix llama gemma * add class attribute * fix CI * clarify whisper * compute_capability * rename names in some comments * Add # fmt: skip * make style * Update tests/models/mistral/test_modeling_mistral.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * update --------- Co-authored-by: Younes Belkada <younesbelkada@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -3245,6 +3245,7 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
@@ -3338,6 +3339,7 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
@@ -3427,6 +3429,7 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky
|
||||
def test_flash_attn_2_generate_left_padding(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
@@ -3470,6 +3473,7 @@ class ModelTesterMixin:
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@is_flaky
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -3888,19 +3892,20 @@ class ModelTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
dummy_input = inputs_dict[model.main_input_name]
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||
batch_size = dummy_attention_mask.shape[0]
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size
|
||||
|
||||
# To avoid errors with padding_side=="right"
|
||||
if is_padding_right:
|
||||
dummy_attention_mask = torch.ones_like(dummy_input)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
@@ -3916,6 +3921,9 @@ class ModelTesterMixin:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
|
||||
_ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
|
||||
# with attention mask
|
||||
_ = model(
|
||||
|
||||
Reference in New Issue
Block a user