[tests] Parameterized test_eager_matches_sdpa_inference (#36650)
This commit is contained in:
@@ -20,7 +20,6 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
|
||||
from tests.test_configuration_common import ConfigTester
|
||||
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
@@ -32,7 +31,6 @@ from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, Co
|
||||
from transformers.models.colpali.processing_colpali import ColPaliProcessor
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_sdpa,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -271,14 +269,6 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
self.skipTest(
|
||||
"Due to custom causal mask, there is a slightly too big difference between eager and sdpa in bfloat16."
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
reason="From PaliGemma: Some undefined behavior encountered with test versions of this model. Skip for now."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user