Add sdpa and FA2 for CLIP (#31940)

* Squashed commit of the following:

commit 102842cd477219b9f9bcb23a0bca3a8b92bd732f
Author: Pavel Iakubovskii <qubvel@gmail.com>
Date:   Fri Jul 12 18:23:52 2024 +0000

    Add model-specific sdpa tests

commit 60e4c88581abf89ec098da84ed8e92aa904c997d
Author: Pavel Iakubovskii <qubvel@gmail.com>
Date:   Fri Jul 12 18:20:53 2024 +0000

    Add fallback to eager (expensive operation)

commit c29033d30e7ffde4327e8a15cbbc6bee37546f80
Author: Pavel Iakubovskii <qubvel@gmail.com>
Date:   Thu Jul 11 17:09:55 2024 +0000

    Fix attn_implementation propagation

commit 783aed05f0f38cb2f99e758f81db6838ac55b9f8
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 09:05:27 2024 +0530

    style

commit e77e703ca75d00447cda277eca6b886cd32bddc0
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 09:04:57 2024 +0530

    add comment to explain why I had to touch forbidden codebase.

commit ab9d8849758e7773a31778ccba71588d18552623
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 09:03:02 2024 +0530

    fix: flax attribute access.

commit c570fc0abf9d1bd58c291aae3c7e384f995996d2
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 08:23:54 2024 +0530

    fix tensorflow attribute name.

commit 32c812871cfdb268d8a6e3e2c61c5c925c8ed47e
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 07:57:10 2024 +0530

    fix attribute access.

commit 4f41a0138b6c417aed9c9332278f8bcd979cb7c2
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 07:44:02 2024 +0530

    _from_config.

commit 35aed64ff602422adcf41d7f677a0a24bd9eccae
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 18:46:52 2024 +0530

    propagation of attn_implementation.

commit 4c25c19845438b1dc1d35a5adf9436151c8c5940
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 09:24:36 2024 +0530

    style again

commit 5f7dc5c5015c0f8116408f737e8c318d1802c80c
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 09:19:05 2024 +0530

    use from_config.

commit b70c409956d0359fa6ae5372275d2a20ba7e3389
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 09:13:43 2024 +0530

    quality

commit a7b63beff53d0fc754c6564e2a7b51731ddee49d
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 14:35:10 2024 +0200

    add benchmark numbers

commit 455b0eaea50862b8458c8f422b60fe60ae40fdcb
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:50:16 2024 +0200

    Revert "reflect feedback more"

    This reverts commit dc123e71eff60aae74d5f325f113d515d0d71117.

commit ca674829d28787349c2a9593a14e0f1d41f04ea4
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:50:05 2024 +0200

    Revert "fix"

    This reverts commit 37a1cb35b87acdc4cf7528b8b1ed6da27d244e52.

commit fab2dd8576c099eb1a3464958cb206a664d28247
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:47:46 2024 +0200

    fix

commit fbc6ae50fd6f2d36294d31e191761631b701d696
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:38:30 2024 +0200

    reflect feedback more

commit 87245bb020b2d60a89afe318a951df0159404fc9
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 08:54:34 2024 +0530

    fixes

commit 1057cc26390ee839251e7f8b3326c4207595fb23
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:49:03 2024 +0530

    don't explicit set attn_implementation in tests

commit e33f75916fc8a99f516b1cf449dbbe9d3aabda81
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:43:54 2024 +0530

    explicitly override attn_implementation in the towers.

commit 4cf41cb1bc885c39df7cb8f2a0694ebf23299235
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:38:42 2024 +0530

    import in one-line.

commit f2cc447ae9e74ccfacb448140cdf88259d4afc8c
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:34:58 2024 +0530

    move sdpa mention to usage tips.

commit 92884766c64dbb456926a3a84dd427be1349fa95
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Mon Apr 29 10:58:26 2024 +0530

    fix: memory allocation problem.

commit d7ffbbfe12f7750b7d0a361420f35c13e0ea787d
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Mon Apr 29 09:56:59 2024 +0530

    fix-copies

commit 8dfc3731cedd02e36acd3fe56bb2e6d61efd25d8
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri Apr 26 20:16:12 2024 +0530

    address arthur's comments.

commit d2ed7b4ce4ff15ae9aa4d3d0500f1544e3dcd9e9
Author: Sayak Paul <spsayakpaul@gmail.com>
Date:   Fri Apr 26 20:08:15 2024 +0530

    Apply suggestions from code review

    Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

commit 46e04361f37ded5c522ff05e9f725b9f82dce40e
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Wed Apr 24 09:55:27 2024 +0530

    add to docs.

commit 831629158ad40d34d8983f209afb2740ba041af2
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Wed Apr 24 09:33:10 2024 +0530

    styling.g

commit d263a119c77314250f4b4c8469caf42559197f22
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Wed Apr 24 09:15:20 2024 +0530

    up

commit d44f9d3d7633d4c241a737a1bc317f791f6aedb3
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 18:40:42 2024 +0530

    handle causal and attention mask

commit 122f1d60153df6666b634a94e38d073f3f260926
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 15:18:21 2024 +0530

    test fixes.

commit 4382d8cff6fa1dee5dbcf0d06b3e2841231e36f5
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 09:39:25 2024 +0530

    fix: scaling inside sdpa.

commit 0f629989efc48b7315cf19405a81e02955efe7e5
Author: Sayak Paul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 08:14:58 2024 +0530

    Update src/transformers/models/clip/modeling_clip.py

    Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

commit 14367316877dc27ea40f767ad1aee38bbc97e4ce
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Mon Apr 22 16:21:36 2024 +0530

    add: sdpa support to clip.

* Remove fallback for empty attention mask (expensive operation)

* Fix typing in copies

* Add flash attention

* Add flash attention tests

* List CLIP in FA docs

* Fix embeddings attributes and tf

* [run-slow] clip

* Update clip documentation

* Remove commented code, skip compile dynamic for CLIPModel

* Fix doc

* Fix doc 2

* Remove double transpose

* Add torch version check for contiguous()

* Add comment to test mixin

* Fix copies

* Add comment for mask

* Update docs

* [run-slow] clip
This commit is contained in:
Pavel Iakubovskii
2024-07-18 06:00:37 +01:00
committed by GitHub
parent b31d595040
commit 1c37e8c1a6
14 changed files with 682 additions and 44 deletions

View File

@@ -18,21 +18,33 @@ import inspect
import os
import tempfile
import unittest
from typing import Optional, Tuple
import numpy as np
import requests
from parameterized import parameterized
from pytest import mark
import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import (
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_torch_sdpa_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@@ -40,6 +52,7 @@ from ...test_modeling_common import (
_config_zero_init,
floats_tensor,
ids_tensor,
is_flaky,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -59,6 +72,10 @@ if is_torch_available():
)
if is_torch_sdpa_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
if is_vision_available():
from PIL import Image
@@ -167,8 +184,180 @@ class CLIPVisionModelTester:
return config, inputs_dict
class CLIPModelTesterMixin(ModelTesterMixin):
"""
Subclass of ModelTesterMixin with methods specific to testing CLIP models.
The SDPA equivalence test is overridden here because CLIP models may have test/vision/text+vision inputs,
different output logits, and are not supposed to be used or tested with padding_side="left".
"""
def test_eager_matches_sdpa_inference(
self,
torch_dtype: str,
use_attention_mask_options: Tuple[Optional[str], ...] = (None, "left", "right"),
logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Convert to torch dtype
dtypes = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
torch_dtype = dtypes[torch_dtype]
atols = {
torch.float32: 1e-5,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
rtols = {
torch.float32: 1e-4,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
atol = atols[torch_dtype]
rtol = rtols[torch_dtype]
def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
# but it would be nicer to have an efficient way to use parameterized.expand
cases = [
(use_mask, output_attentions, sdpa_backend, batch_size)
for use_mask in use_attention_mask_options
for output_attentions in [True, False]
for sdpa_backend in [
[SDPBackend.MATH],
[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
]
for batch_size in [1, 5]
]
fail_cases = []
for use_mask, output_attentions, sdpa_backend, batch_size in cases:
processed_inputs = inputs_dict.copy()
# convert to torch_dtype
if "pixel_values" in processed_inputs:
processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)
# slice for different batch sizes
for key in ["pixel_values", "input_ids", "attention_mask"]:
if key in processed_inputs:
processed_inputs[key] = processed_inputs[key][:batch_size]
# set attention mask with left padding
if not use_mask:
processed_inputs.pop("attention_mask", None)
elif use_mask == "left":
dummy_attention_mask = processed_inputs["attention_mask"]
dummy_attention_mask[:] = 1
dummy_attention_mask[:, :1] = 0
processed_inputs["attention_mask"] = dummy_attention_mask
elif use_mask == "right":
dummy_attention_mask = processed_inputs["attention_mask"]
dummy_attention_mask[:] = 1
dummy_attention_mask[:, -1:] = 0
processed_inputs["attention_mask"] = dummy_attention_mask
else:
raise ValueError(f"Invalid value for use_mask={use_mask}")
processed_inputs["output_attentions"] = output_attentions
processed_inputs["output_hidden_states"] = True
current_case = f"use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
with torch.no_grad():
try:
with sdpa_kernel(sdpa_backend):
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
except Exception as e:
fail_cases.append(f"{current_case}: {e}")
continue
keys = set(logit_keys) & set(outputs_eager.keys())
self.assertTrue(
keys, f"Keys {logit_keys} not found in outputs. Available keys: {outputs_eager.keys()}"
)
for key in keys:
try:
eager_logits = outputs_eager[key]
sdpa_logits = outputs_sdpa[key]
except KeyError:
raise KeyError(f"Key {key} not found in outputs. Available keys: {outputs_eager.keys()}")
if "hidden_state" in key and use_mask == "left":
eager_logits = eager_logits[:, 1:]
sdpa_logits = sdpa_logits[:, 1:]
elif "hidden_state" in key and use_mask == "right":
eager_logits = eager_logits[:, :-1]
sdpa_logits = sdpa_logits[:, :-1]
is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
if not is_close:
fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_torch
class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as CLIP does not use input_ids, inputs_embeds,
attention_mask and seq_length.
@@ -261,6 +450,17 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "visual_projection"))
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("last_hidden_state", "pooler_output", "image_embeds"),
use_attention_mask_options=(None,),
)
class CLIPTextModelTester:
def __init__(
@@ -361,7 +561,7 @@ class CLIPTextModelTester:
@require_torch
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
class CLIPTextModelTest(CLIPModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPTextModel, CLIPTextModelWithProjection) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
@@ -428,6 +628,21 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "text_projection"))
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("last_hidden_state", "pooler_output", "text_embeds"),
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
)
@require_torch_sdpa
def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="CLIPTextModel has two attention masks: `causal_attention_mask` and `attention_mask`")
class CLIPModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -479,7 +694,7 @@ class CLIPModelTester:
@require_torch
class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
@@ -746,6 +961,115 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = CLIPModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("logits_per_image", "logits_per_text"),
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
)
@require_torch_sdpa
def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="CLIP text tower has two attention masks: `causal_attention_mask` and `attention_mask`")
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="CLIP model can't be compiled dynamic, error in clip_loss`")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
dummy_input_ids = inputs_dict["input_ids"]
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
)
self.assertTrue(
torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
)
self.assertTrue(
torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
dummy_input_ids = inputs_dict["input_ids"]
dummy_pixel_mask = inputs_dict["attention_mask"]
# right padding
dummy_pixel_mask[:] = 1
dummy_pixel_mask[:, -1:] = 0
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
)
logits_per_image_eager = outputs.logits_per_image[:, :-1]
logits_per_text_eager = outputs.logits_per_text[:, :-1]
logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1]
logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1]
self.assertTrue(
torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2),
f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}",
)
self.assertTrue(
torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
)
class CLIPForImageClassificationModelTester(CLIPModelTester):
def __init__(self, parent):
@@ -769,7 +1093,7 @@ class CLIPForImageClassificationModelTester(CLIPModelTester):
@require_torch
class CLIPForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CLIPForImageClassification,) if is_torch_available() else ()
pipeline_model_mapping = {"image-classification": CLIPForImageClassification} if is_torch_available() else {}
fx_compatible = False
@@ -805,6 +1129,17 @@ class CLIPForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin,
def test_initialization(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("logits",),
use_attention_mask_options=(None,),
)
# We will verify our results on an image of cute cats
def prepare_img():