[pipeline] A simple fix for half-precision & 8bit models (#21479)
* v1 fix * adapt from suggestions * make style * fix tests * add gpu tests * update docs * fix other tests * Apply suggestions from code review Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * better fix * make fixup * better example * revert changes * proposal * more elegant solution * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -105,6 +105,8 @@ If the model is too large for a single GPU, you can set `device_map="auto"` to a
|
||||
generator(model="openai/whisper-large", device_map="auto")
|
||||
```
|
||||
|
||||
Note that if `device_map="auto"` is passed, there is no need to add the argument `device=device` when instantiating your `pipeline` as you may encounter some unexpected behavior!
|
||||
|
||||
### Batch size
|
||||
|
||||
By default, pipelines will not batch inference for reasons explained in detail [here](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching). The reason is that batching is not necessarily faster, and can actually be quite slower in some cases.
|
||||
@@ -257,4 +259,32 @@ sudo apt install -y tesseract-ocr
|
||||
pip install pytesseract
|
||||
```
|
||||
|
||||
</Tip>
|
||||
</Tip>
|
||||
|
||||
## Using `pipeline` on large models with 🤗 `accelerate`:
|
||||
|
||||
You can easily run `pipeline` on large models using 🤗 `accelerate`! First make sure you have installed `accelerate` with `pip install accelerate`.
|
||||
|
||||
First load your model using `device_map="auto"`! We will use `facebook/opt-1.3b` for our example.
|
||||
|
||||
```py
|
||||
# pip install accelerate
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(model="facebook/opt-1.3b", torch_dtype=torch.bfloat16, device_map="auto")
|
||||
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
|
||||
```
|
||||
|
||||
You can also pass 8-bit loaded models if you install `bitsandbytes` and add the argument `load_in_8bit=True`
|
||||
|
||||
```py
|
||||
# pip install accelerate bitsandbytes
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(model="facebook/opt-1.3b", device_map="auto", model_kwargs={"load_in_8bit": True})
|
||||
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
|
||||
```
|
||||
|
||||
Note that you can replace the checkpoint with any of the Hugging Face model that supports large model loading such as BLOOM!
|
||||
Reference in New Issue
Block a user