use pytest.mark directly (#27390)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-11-09 13:32:54 +01:00
committed by GitHub
parent 791ec370d1
commit 3258ff9330
5 changed files with 14 additions and 14 deletions

View File

@@ -20,7 +20,7 @@ import inspect
import tempfile import tempfile
import unittest import unittest
from pytest import mark import pytest
from transformers import ( from transformers import (
BarkCoarseConfig, BarkCoarseConfig,
@@ -877,7 +877,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_inference(self): def test_flash_attn_2_inference(self):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
@@ -936,7 +936,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:

View File

@@ -16,7 +16,7 @@ import os
import tempfile import tempfile
import unittest import unittest
from pytest import mark import pytest
from transformers import DistilBertConfig, is_torch_available from transformers import DistilBertConfig, is_torch_available
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device
@@ -290,7 +290,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test. # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn @require_flash_attn
@require_torch_accelerator @require_torch_accelerator
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_inference(self): def test_flash_attn_2_inference(self):
import torch import torch
@@ -344,7 +344,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test. # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn @require_flash_attn
@require_torch_accelerator @require_torch_accelerator
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
import torch import torch

View File

@@ -17,8 +17,8 @@
import unittest import unittest
import pytest
from parameterized import parameterized from parameterized import parameterized
from pytest import mark
from transformers import LlamaConfig, is_torch_available, set_seed from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -385,7 +385,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_generate_padding_right(self): def test_flash_attn_2_generate_padding_right(self):
""" """

View File

@@ -19,7 +19,7 @@ import gc
import tempfile import tempfile
import unittest import unittest
from pytest import mark import pytest
from transformers import AutoTokenizer, MistralConfig, is_torch_available from transformers import AutoTokenizer, MistralConfig, is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -369,7 +369,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_generate_padding_right(self): def test_flash_attn_2_generate_padding_right(self):
import torch import torch
@@ -403,7 +403,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
import torch import torch

View File

@@ -21,7 +21,7 @@ import tempfile
import unittest import unittest
import numpy as np import numpy as np
from pytest import mark import pytest
import transformers import transformers
from transformers import WhisperConfig from transformers import WhisperConfig
@@ -800,7 +800,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_inference(self): def test_flash_attn_2_inference(self):
import torch import torch
@@ -845,7 +845,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @pytest.mark.flash_attn_test
@slow @slow
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
import torch import torch