[docs] Increase visibility of torch_dtype="auto" (#35067)
* auto-dtype * feedback
This commit is contained in:
@@ -19,6 +19,7 @@ Before you begin, make sure the following libraries are installed with their lat
|
||||
pip install --upgrade torch torchao
|
||||
```
|
||||
|
||||
By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -28,7 +29,7 @@ model_name = "meta-llama/Meta-Llama-3-8B"
|
||||
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
|
||||
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_text = "What are we having for dinner?"
|
||||
|
||||
Reference in New Issue
Block a user