[models] respect dtype of the model when instantiating it (#12316)

* [models] respect dtype of the model when instantiating it

* cleanup

* cleanup

* rework to handle non-float dtype

* fix

* switch to fp32 tiny model

* improve

* use dtype.is_floating_point

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix the doc

* recode to use explicit torch_dtype_auto_detect, torch_dtype args

* docs and tweaks

* docs and tweaks

* docs and tweaks

* merge 2 args, add docs

* fix

* fix

* better doc

* better doc

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-06-28 20:11:21 -07:00
committed by GitHub
parent 31c3e7e75b
commit 7682e97702
8 changed files with 221 additions and 26 deletions

View File

@@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU
For full details on this method and other related features please refer to `Constructing Massive Models
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
Also when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use
``torch_dtype=torch.float16``. For details, please, see :ref:`from_pretrained-torch-dtype`.
Gathering Parameters