[LlamaFamiliy] add a tip about dtype (#25794)
* add a warning=True tip to the Llama2 doc * code llama needs a tip too * doc nit * build PR doc * doc nits Co-authored-by: Lysandre <lysandre@huggingface.co> --------- Co-authored-by: Lysandre <lysandre@huggingface.co>
This commit is contained in:
@@ -26,6 +26,17 @@ The abstract from the paper is the following:
|
||||
|
||||
Checkout all Llama2 models [here](https://huggingface.co/models?search=llama2)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The `Llama2` models were trained using `bfloat16`, but the original inference uses `float16. The checkpoints uploaded on the hub use `torch_dtype = 'float16'` which will be
|
||||
used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.
|
||||
|
||||
The `dtype` of the online weights is mostly irrelevant, unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online) then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`) and finally, if there is a `torch_dtype` provided in the config, it will be used.
|
||||
|
||||
Training the model in `float16` is not recommended and known to produce `nan`, as such the model should be trained in `bfloat16`.
|
||||
|
||||
</Tip>
|
||||
|
||||
Tips:
|
||||
|
||||
- Weights for the Llama2 models can be obtained by filling out [this form](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
||||
|
||||
Reference in New Issue
Block a user