[docs] Redesign (#31757)

* toctree

* not-doctested.txt

* collapse sections

* feedback

* update

* rewrite get started sections

* fixes

* fix

* loading models

* fix

* customize models

* share

* fix link

* contribute part 1

* contribute pt 2

* fix toctree

* tokenization pt 1

* Add new model (#32615)

* v1 - working version

* fix

* fix

* fix

* fix

* rename to correct name

* fix title

* fixup

* rename files

* fix

* add copied from on tests

* rename to `FalconMamba` everywhere and fix bugs

* fix quantization + accelerate

* fix copies

* add `torch.compile` support

* fix tests

* fix tests and add slow tests

* copies on config

* merge the latest changes

* fix tests

* add few lines about instruct

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix tests

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* "to be not" -> "not to be" (#32636)

* "to be not" -> "not to be"

* Update sam.md

* Update trainer.py

* Update modeling_utils.py

* Update test_modeling_utils.py

* Update test_modeling_utils.py

* fix hfoption tag

* tokenization pt. 2

* image processor

* fix toctree

* backbones

* feature extractor

* fix file name

* processor

* update not-doctested

* update

* make style

* fix toctree

* revision

* make fixup

* fix toctree

* fix

* make style

* fix hfoption tag

* pipeline

* pipeline gradio

* pipeline web server

* add pipeline

* fix toctree

* not-doctested

* prompting

* llm optims

* fix toctree

* fixes

* cache

* text generation

* fix

* chat pipeline

* chat stuff

* xla

* torch.compile

* cpu inference

* toctree

* gpu inference

* agents and tools

* gguf/tiktoken

* finetune

* toctree

* trainer

* trainer pt 2

* optims

* optimizers

* accelerate

* parallelism

* fsdp

* update

* distributed cpu

* hardware training

* gpu training

* gpu training 2

* peft

* distrib debug

* deepspeed 1

* deepspeed 2

* chat toctree

* quant pt 1

* quant pt 2

* fix toctree

* fix

* fix

* quant pt 3

* quant pt 4

* serialization

* torchscript

* scripts

* tpu

* review

* model addition timeline

* modular

* more reviews

* reviews

* fix toctree

* reviews reviews

* continue reviews

* more reviews

* modular transformers

* more review

* zamba2

* fix

* all frameworks

* pytorch

* supported model frameworks

* flashattention

* rm check_table

* not-doctested.txt

* rm check_support_list.py

* feedback

* updates/feedback

* review

* feedback

* fix

* update

* feedback

* updates

* update

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
Steven Liu
2025-03-03 10:33:46 -08:00
committed by GitHub
parent 6aa9888463
commit c0f8d055ce
423 changed files with 10925 additions and 14569 deletions

View File

@@ -9,46 +9,42 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
# LLM inference optimization
# Optimizing inference
Large language models (LLMs) have pushed text generation applications, such as chat and code completion models, to the next level by producing text that displays a high level of understanding and fluency. But what makes LLMs so powerful - namely their size - also presents challenges for inference.
Inference with large language models (LLMs) can be challenging because they have to store and handle billions of parameters. To load a 70B parameter [Llama 2](https://hf.co/meta-llama/Llama-2-70b-hf) model, it requires 256GB of memory for full precision weights and 128GB of memory for half-precision weights. The most powerful GPUs today - the A100 and H100 - only have 80GB of memory.
Basic inference is slow because LLMs have to be called repeatedly to generate the next token. The input sequence increases as generation progresses, which takes longer and longer for the LLM to process. LLMs also have billions of parameters, making it a challenge to store and handle all those weights in memory.
On top of the memory requirements, inference is slow because LLMs are called repeatedly to generate the next token. The input sequence increases as generation progresses, which takes longer and longer to process.
This guide will show you how to use the optimization techniques available in Transformers to accelerate LLM inference.
This guide will show you how to optimize LLM inference to accelerate generation and reduce memory usage.
> [!TIP]
> Hugging Face also provides [Text Generation Inference (TGI)](https://hf.co/docs/text-generation-inference), a library dedicated to deploying and serving highly optimized LLMs for inference. It includes deployment-oriented optimization features not included in Transformers, such as continuous batching for increasing throughput and tensor parallelism for multi-GPU inference.
> Try out [Text Generation Inference (TGI)](https://hf.co/docs/text-generation-inference), a Hugging Face library dedicated to deploying and serving highly optimized LLMs for inference.
## Static kv-cache and `torch.compile`
## Static kv-cache and torch.compile
During decoding, a LLM computes the key-value (kv) values for each input token and since it is autoregressive, it computes the same kv values each time because the generated output becomes part of the input now. This is not very efficient because you're recomputing the same kv values each time.
LLMs compute key-value (kv) values for each input token, and it performs the same kv computation each time because the generated output becomes part of the input. However, performing the same kv computation every time is not very efficient.
To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels. We have an entire guide dedicated to kv-caches [here](./kv_cache).
A *kv-cache* stores the past keys and values instead of recomputing them each time. As a result, the kv-cache is dynamic and it grows with each generation step which prevents you from taking advantage of [torch.compile](./perf_torch_compile), a powerful optimization method that fuses PyTorch code into optimized kernels.
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.
The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value, so you can combine it with [torch.compile](./perf_torch_compile) for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.
> [!WARNING]
> Currently, only [Llama](./model_doc/llama2) and a few other models support static kv-cache and `torch.compile`. Check [this issue](https://github.com/huggingface/transformers/issues/28981) for a live model compatibility list.
> Follow this [issue](https://github.com/huggingface/transformers/issues/28981) to track which models (Llama, Gemma, Mistral, etc.) support a static kv-cache and torch.compile.
There are three flavors of static kv-cache usage, depending on the complexity of your task:
1. Basic usage: simply set a flag in `generation_config` (recommended);
2. Advanced usage: handle a cache object for multi-turn generation or a custom generation loop;
3. Advanced usage: compile the entire `generate` function into a single graph, if having a single graph is relevant for you.
Depending on your task, there are several ways you can use the static kv-cache.
Select the correct tab below for further instructions on each of these flavors.
1. For basic use cases, set [cache_implementation](https://hf.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.cache_implementation) to `"static"` (recommended).
2. For multi-turn generation or a custom generation loop, initialize and handle [`StaticCache`] directly.
3. For more unique hardware or use cases, it may be better to compile the entire [`~GenerationMixin.generate`] function into a single graph.
> [!TIP]
> Regardless of the strategy used with `torch.compile`, you can avoid shape-related recompilations if you left-pad your LLM inputs to a limited set of values. The [`pad_to_multiple_of` tokenizer flag](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) is your friend!
> Regardless of how you use the static kv-cache and torch.compile, left-pad your inputs with [pad_to_multiple_of](https://hf.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.pad_to_multiple_of) to a limited set of values to avoid shape-related recompilations.
<hfoptions id="static-kv">
<hfoption id="basic usage: generation_config">
<hfoption id="1. cache_implementation">
For this example, let's use the [Gemma](https://hf.co/google/gemma-2b) model. All we need to do is to:
1. Access the model's `generation_config` attribute and set the `cache_implementation` to "static";
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
And that's it!
1. Set the [cache_implementation](https://hf.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.cache_implementation) to `"static"` in a models [`GenerationConfig`].
2. Call [torch.compile](./perf_torch_compile) to compile the forward pass with the static kv-cache.
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -70,17 +66,15 @@ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```
Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. Avoiding re-compilation is critical to get the most out of `torch.compile`, and you should be aware of the following:
1. If the batch size changes or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation;
2. The first couple of calls of the compiled function are slower, as the function is being compiled.
Under the hood, [`~GenerationMixin.generate`] attempts to reuse the same cache object to avoid recompilation at each call, which is critical to get the most out of [torch.compile](./perf_torch_compile). Be aware of the following to avoid triggering recompilation or if generation is slower than expected.
> [!WARNING]
> For a more advanced usage of the static cache, such as multi-turn conversations, we recommend instantiating and manipulating the cache object outside [`~GenerationMixin.generate`]. See the advanced usage tab.
1. If the batch size changes or the maximum output length increases between calls, the cache is reinitialized and recompiled.
2. The first several calls of the compiled function are slower because it is being compiled.
</hfoption>
<hfoption id="advanced usage: control Static Cache">
<hfoption id="2. StaticCache">
A [`StaticCache`] object can be passed to the model's [`~GenerationMixin.generate`] under the `past_key_values` argument. The object will retain the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, like you would do with a dynamic cache.
Directly initialize a [`StaticCache`] object and pass it to the `past_key_values` parameter in [`~GenerationMixin.generate`]. The [`StaticCache`] keeps the cache contents, so you can pass it to a new [`~GenerationMixin.generate`] call to continue generation, similar to a dynamic cache.
```py
from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
@@ -118,9 +112,9 @@ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
> [!TIP]
> If you want to reuse the same [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method between calls
> To reuse [`StaticCache`] on a new prompt, use [`~StaticCache.reset`] to reset the cache contents between calls.
If you want to go further down a level, the [`StaticCache`] object can also be passed to the model's forward pass under the same `past_key_values` argument. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens.
Another option for using [`StaticCache`] is to pass it to a models forward pass using the same `past_key_values` argument. This allows you to write your own custom decoding function to decode the next token given the current token, position, and cache position of previously generated tokens.
```py
from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging
@@ -153,10 +147,11 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
return new_token
```
There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
3. Use `SDPBackend.MATH` in the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
To enable static kv-cache and [torch.compile](./perf_torch_compile) with [`StaticCache`], follow the steps below.
1. Initialize [`StaticCache`] before using the model for inference to configure parameters like the maximum batch size and sequence length.
2. Call [torch.compile](./perf_torch_compile) on the model to compile the forward pass with the static kv-cache.
3. se SDPBackend.MATH in the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
```py
from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -193,9 +188,9 @@ text
```
</hfoption>
<hfoption id="advanced usage: end-to-end generate compilation">
<hfoption id="3. compile entire generate function">
Compiling the entire `generate` function, in terms of code, is even simpler than in the basic usage: call `torch.compile` on `generate` to compile the entire function. No need to specify the use of the static cache: although it is compatible, dynamic cache (default) was faster in our benchmarks.
Compiling the entire [`~GenerationMixin.generate`] function also compiles the input preparation logit processor operations, and more, in addition to the forward pass. With this approach, you don't need to initialize [`StaticCache`] or set the [cache_implementation](https://hf.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.cache_implementation) parameter.
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -215,28 +210,33 @@ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['The theory of special relativity states 1. The speed of light is constant in all inertial reference']
```
As a result, we compile not only the model forward pass, but also all input preparation, logit processor operations, and so on. The result should be a slightly `generate` call, compared to the basic usage example, and the compiled graph may be better suited to more exotic hardware devices or use cases. However, there are severe drawbacks in using this approach:
1. Compilation is much slower;
2. All parameterization of `generate` must be done through `generation_config`;
3. Many warnings and exceptions are suppressed -- we suggest testing with its uncompiled form first;
4. Although we are working on it, it is heavily feature restricted (for instance, at the time of writing, generation does not stop if an EOS token is selected).
This usage pattern is more appropriate for unique hardware or use cases, but there are several drawbacks to consider.
1. Compilation is much slower.
2. Parameters must be configured through [`GenerationConfig`].
3. Many warnings and exceptions are suppressed. We recommend testing the uncompiled model first.
4. Many features are unavailable at the moment. For example, generation does not stop if an `EOS` token is selected.
</hfoption>
</hfoptions>
## Speculative decoding
## Decoding strategies
Decoding can also be optimized to accelerate generation. You can use a lightweight assistant model to generate candidate tokens faster than the LLM itself or you can use a variant of this decoding strategy that works especially well for input-grounded tasks.
### Speculative decoding
> [!TIP]
> For a more in-depth explanation, take a look at the [Assisted Generation: a new direction toward low-latency text generation](https://hf.co/blog/assisted-generation) blog post!
Another issue with autoregression is that for each input token you need to load the model weights each time during the forward pass. This is slow and cumbersome for LLMs which have billions of parameters. Speculative decoding alleviates this slowdown by using a second smaller and faster assistant model to generate candidate tokens that are verified by the larger LLM in a single forward pass. If the verified tokens are correct, the LLM essentially gets them for "free" without having to generate them itself. There is no degradation in accuracy because the verification forward pass ensures the same outputs are generated as if the LLM had generated them on its own.
For each input token, the model weights are loaded each time during the forward pass, which is slow and cumbersome when a model has billions of parameters. Speculative decoding alleviates this slowdown by using a second smaller and faster assistant model to generate candidate tokens that are verified by the larger model in a single forward pass. If the verified tokens are correct, the LLM essentially gets them for "free" without having to generate them itself. There is no degradation in accuracy because the verification forward pass ensures the same outputs are generated as if the LLM had generated them on its own.
To get the largest speed up, the assistant model should be a lot smaller than the LLM so that it can generate tokens quickly. The assistant and LLM model must also share the same tokenizer to avoid re-encoding and decoding tokens.
> [!WARNING]
> Speculative decoding is only supported for the greedy search and sampling decoding strategies, and it also doesn't support batched inputs.
> Speculative decoding is only supported for the greedy search and sampling decoding strategies, and it doesn't support batched inputs.
Enable speculative decoding by loading an assistant model and passing it to the [`~GenerationMixin.generate`] method.
Enable speculative decoding by loading an assistant model and passing it to [`~GenerationMixin.generate`].
<hfoptions id="spec-decoding">
<hfoption id="greedy search">
@@ -261,7 +261,7 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
</hfoption>
<hfoption id="sampling">
For speculative sampling decoding, add the `do_sample` and `temperature` parameters to the [`~GenerationMixin.generate`] method in addition to the assistant model.
For speculative sampling decoding, add the [do_sample](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.do_sample) and [temperature](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.temperature) parameters to [`~GenerationMixin.generate`].
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -287,7 +287,7 @@ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
Prompt lookup decoding is a variant of speculative decoding that is also compatible with greedy search and sampling. Prompt lookup works especially well for input-grounded tasks - such as summarization - where there is often overlapping words between the prompt and output. These overlapping n-grams are used as the LLM candidate tokens.
To enable prompt lookup decoding, specify the number of tokens that should be overlapping in the `prompt_lookup_num_tokens` parameter. Then you can pass this parameter to the [`~GenerationMixin.generate`] method.
To enable prompt lookup decoding, specify the number of tokens that should be overlapping in the [prompt_lookup_num_tokens](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.prompt_lookup_num_tokens) parameter. Then pass this parameter to [`~GenerationMixin.generate`].
<hfoptions id="pld">
<hfoption id="greedy decoding">
@@ -312,7 +312,7 @@ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
</hfoption>
<hfoption id="sampling">
For prompt lookup decoding with sampling, add the `do_sample` and `temperature` parameters to the [`~GenerationMixin.generate`] method.
For prompt lookup decoding with sampling, add the [do_sample](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.do_sample) and [temperature](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig.temperature) parameters to [`~GenerationMixin.generate`].
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -333,15 +333,15 @@ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
</hfoption>
</hfoptions>
## Attention optimizations
## Attention
A known issue with transformer models is that the self-attention mechanism grows quadratically in compute and memory with the number of input tokens. This limitation is only magnified in LLMs which handles much longer sequences. To address this, try FlashAttention2 or PyTorch's scaled dot product attention (SDPA), which are more memory efficient attention implementations and can accelerate inference.
A known issue with transformer models is that the self-attention mechanism grows quadratically in compute and memory with the number of input tokens. This limitation is only magnified in LLMs which handles much longer sequences. To address this, try FlashAttention2 or PyTorch's scaled dot product attention (SDPA), which are more memory efficient attention implementations.
### FlashAttention-2
FlashAttention and [FlashAttention-2](./perf_infer_gpu_one#flashattention-2) break up the attention computation into smaller chunks and reduces the number of intermediate read/write operations to GPU memory to speed up inference. FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead.
FlashAttention and [FlashAttention-2](./perf_infer_gpu_one#flashattention-2) break up the attention computation into smaller chunks and reduces the number of intermediate read/write operations to the GPU memory to speed up inference. FlashAttention-2 improves on the original FlashAttention algorithm by also parallelizing over sequence length dimension and better partitioning work on the hardware to reduce synchronization and communication overhead.
To use FlashAttention-2, set `attn_implementation="flash_attention_2"` in the [`~PreTrainedModel.from_pretrained`] method.
To use FlashAttention-2, set [attn_implementation](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.attn_implementation) to `"flash_attention_2"` in [`~PreTrainedModel.from_pretrained`].
```py
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
@@ -355,105 +355,12 @@ model = AutoModelForCausalLM.from_pretrained(
)
```
### Fine-Tuning with torch.compile and Padding-Free Data Collation
In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. This approach can significantly speed up training and reduce computational overhead.
Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator:
```
#################### IMPORTS ###################
import math
import datasets
import dataclasses
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments
)
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
#################### MODEL LOADING WITH FLASH ATTENTION ###################
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2" # Enables FlashAttention-2
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
#################### DATA PREPROCESSING (PADDING-FREE) ###################
response_template = "\n### Label:"
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:] # Exclude special tokens
data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids=response_template_ids,
tokenizer=tokenizer,
ignore_index=-100,
padding_free=True # Enables padding-free collation
)
def format_dataset(example):
return {
"output": example["output"] + tokenizer.eos_token
}
data_files = {"train": "path/to/dataset"} # Replace with your dataset path
json_dataset = datasets.load_dataset("json", data_files=data_files)
formatted_train_dataset = json_dataset["train"].map(format_dataset)
################# TRAINING CONFIGURATION ############################
train_args = TrainingArguments(
num_train_epochs=5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=1e-5,
weight_decay=0.0,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
logging_steps=1,
include_tokens_per_second=True,
save_strategy="epoch",
output_dir="output",
torch_compile=True, # Enables torch.compile
torch_compile_backend="inductor",
torch_compile_mode="default"
)
# Convert TrainingArguments to SFTConfig
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
transformer_kwargs = {
k: v
for k, v in train_args.to_dict().items()
if k in transformer_train_arg_fields
}
training_args = SFTConfig(**transformer_kwargs)
####################### FINE-TUNING #####################
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=formatted_train_dataset,
data_collator=data_collator,
dataset_text_field="output",
args=training_args,
)
trainer.train()
```
### PyTorch scaled dot product attention
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
> [!TIP]
> SDPA supports FlashAttention-2 as long as you have the latest PyTorch version installed.
> SDPA automaticallysupports FlashAttention-2 as long as you have the latest PyTorch version installed.
Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, use `SDPBackend.FLASH_ATTENTION` to enable FlashAttention.
@@ -473,12 +380,14 @@ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
## Quantization
Quantization reduces the size of the LLM weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you're constrained by your GPUs memory. If you aren't limited by your GPU, you don't necessarily need to quantize your model because it can incur a small latency cost (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights.
Quantization reduces the size of model weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you're constrained by GPU memory.
If you aren't limited by your GPU, you don't necessarily need to quantize your model because it can increase latency slightly (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights.
> [!TIP]
> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, VPTQ, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes.
Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating how much memory it costs to load [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).
Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating the memory required to load [Mistral-7B-v0.1](https://hf.co/mistralai/Mistral-7B-v0.1).
<iframe
src="https://hf-accelerate-model-memory-usage.hf.space"
@@ -487,7 +396,7 @@ Use the Model Memory Calculator below to estimate and compare how much memory is
height="450"
></iframe>
To load Mistral-7B-v0.1 in half-precision, set the `torch_dtype` parameter in the [`~transformers.AutoModelForCausalLM.from_pretrained`] method to `torch.bfloat16`. This requires 13.74GB of memory.
To load a model in half-precision, set the [torch_dtype](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.PreTrainedModel.from_pretrained.torch_dtype) parameter in [`~transformers.AutoModelForCausalLM.from_pretrained`] to `torch.bfloat16`. This requires 13.74GB of memory.
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -498,7 +407,7 @@ model = AutoModelForCausalLM.from_pretrained(
)
```
To load a quantized model (8-bit or 4-bit) for inference, try [bitsandbytes](https://hf.co/docs/bitsandbytes) and set the `load_in_4bit` or `load_in_8bit` parameters to `True`. Loading the model in 8-bits only requires 6.87 GB of memory.
To load a quantized model (8-bit or 4-bit), try [bitsandbytes](https://hf.co/docs/bitsandbytes) and set the [load_in_4bit](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.BitsAndBytesConfig.load_in_4bit) or [load_in_8bit](https://hf.co/docs/transformers/main/en/main_classes/text_generation#transformers.BitsAndBytesConfig.load_in_8bit) parameters to `True`. Loading the model in 8-bits only requires 6.87 GB of memory.
```py
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig