[models] respect dtype of the model when instantiating it (#12316)
* [models] respect dtype of the model when instantiating it * cleanup * cleanup * rework to handle non-float dtype * fix * switch to fp32 tiny model * improve * use dtype.is_floating_point * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix the doc * recode to use explicit torch_dtype_auto_detect, torch_dtype args * docs and tweaks * docs and tweaks * docs and tweaks * merge 2 args, add docs * fix * fix * better doc * better doc Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU
|
||||
For full details on this method and other related features please refer to `Constructing Massive Models
|
||||
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
|
||||
|
||||
Also when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use
|
||||
``torch_dtype=torch.float16``. For details, please, see :ref:`from_pretrained-torch-dtype`.
|
||||
|
||||
|
||||
Gathering Parameters
|
||||
|
||||
Reference in New Issue
Block a user