Add FA2 and sdpa support for SigLIP (#31499)

* Rebase to main

* Fix attention implementation autoset for tex and vision configs

* Fixup

* Minor fixes

* Fix copies

* Fix attention_mask for FA2

* Add eqvivalence tests for siglip

* Remove right padding test

* Uncomment flaky

* Fix import

* Add to docs

* Fix test message

* Add sdpa

* Add sdpa equivalence test

* Add siglip sdpa to docs

* Fix typing for attention output

* Add sdpa tests

* Fix signature of FA2

* Autoset attn_implementation in config

* Rename bsz -> batch_size

* Move back autoset attn method

* Mark as flaky

* Correct attention mask padding

* [run-slow] siglip

* Add FA2 and sdpa docs

* Style fix

* Remove flaky for FA2 test

* Change attention implementation set

* Change attn_implementaiton propogation

* Fix typos

* Add modality to assert message

* Add more sdpa backends in test

* [run slow] siglip

* Add math sdpa backend for all options

* [run slow] siglip
This commit is contained in:
Pavel Iakubovskii
2024-07-08 11:10:02 +01:00
committed by GitHub
parent 076e66e479
commit a177821b24
5 changed files with 680 additions and 13 deletions

View File

@@ -107,6 +107,88 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
## Combining SigLIP and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2.
```bash
pip install -U flash-attn --no-build-isolation
```
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import SiglipProcessor, SiglipModel
>>> device = "cuda" # the device to load the model onto
>>> model = SiglipModel.from_pretrained(
... "google/siglip-so400m-patch14-384",
... attn_implementation="flash_attention_2",
... torch_dtype=torch.float16,
... device_map=device,
... )
>>> processor = SiglipProcessor.from_pretrained("google/siglip-so400m-patch14-384")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get same results
>>> candidate_labels = [f'This is a photo of {label}.' for label in candidate_labels]
# important: we pass `padding=max_length` since the model was trained with this
>>> inputs = processor(text=candidate_labels, images=image, padding="max_length", return_tensors="pt")
>>> inputs.to(device)
>>> with torch.no_grad():
... with torch.autocast(device):
... outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
51.3% that image 0 is 'This is a photo of 2 cats.'
```
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
You may set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. Make sure you have `torch>=2.1.1`.
```python
>>> from transformers import SiglipModel
>>> model = SiglipModel.from_pretrained(
... "google/siglip-so400m-patch14-384",
... attn_implementation="sdpa",
... torch_dtype=torch.float16,
... device_map=device,
... )
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
## Expected speedups
Below is an expected speedup diagram that compares inference time between the native implementation in transformers using `google/siglip-so400m-patch14-384` checkpoint in `float16` precision and the Flash Attention 2 / SDPA version of the model using different batch sizes.
<div style="text-align: center">
<img src="https://i.imgur.com/cWm4rsn.png">
</div>
## SiglipConfig
[[autodoc]] SiglipConfig