[Docs] More clarifications on BT + FA (#25823)
This commit is contained in:
@@ -74,7 +74,7 @@ import torch
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to("cuda")
|
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
|
||||||
# convert the model to BetterTransformer
|
# convert the model to BetterTransformer
|
||||||
model.to_bettertransformer()
|
model.to_bettertransformer()
|
||||||
|
|
||||||
@@ -99,6 +99,8 @@ try using the PyTorch nightly version, which may have a broader coverage for Fla
|
|||||||
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
|
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Or make sure your model is correctly cast to float16 or bfloat16.
|
||||||
|
|
||||||
|
|
||||||
Have a look at [this detailed blogpost](https://pytorch.org/blog/out-of-the-box-acceleration/) to read more about what is possible to do with `BetterTransformer` + SDPA API.
|
Have a look at [this detailed blogpost](https://pytorch.org/blog/out-of-the-box-acceleration/) to read more about what is possible to do with `BetterTransformer` + SDPA API.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user