[models] respect dtype of the model when instantiating it (#12316)
* [models] respect dtype of the model when instantiating it * cleanup * cleanup * rework to handle non-float dtype * fix * switch to fp32 tiny model * improve * use dtype.is_floating_point * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix the doc * recode to use explicit torch_dtype_auto_detect, torch_dtype args * docs and tweaks * docs and tweaks * docs and tweaks * merge 2 args, add docs * fix * fix * better doc * better doc Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU
|
||||
For full details on this method and other related features please refer to `Constructing Massive Models
|
||||
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
|
||||
|
||||
Also when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use
|
||||
``torch_dtype=torch.float16``. For details, please, see :ref:`from_pretrained-torch-dtype`.
|
||||
|
||||
|
||||
Gathering Parameters
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
..
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
@@ -38,6 +38,37 @@ PreTrainedModel
|
||||
:members:
|
||||
|
||||
|
||||
.. _from_pretrained-torch-dtype:
|
||||
|
||||
Model Instantiation dtype
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Under Pytorch a model normally gets instantiated with ``torch.float32`` format. This can be an issue if one tries to
|
||||
load a model whose weights are in fp16, since it'd require twice as much memory. To overcome this limitation, you can
|
||||
either explicitly pass the desired ``dtype`` using ``torch_dtype`` argument:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16)
|
||||
|
||||
or, if you want the model to always load in the most optimal memory pattern, you can use the special value ``"auto"``,
|
||||
and then ``dtype`` will be automatically derived from the model's weights:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto")
|
||||
|
||||
Models instantiated from scratch can also be told which ``dtype`` to use with:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
config = T5Config.from_pretrained("t5")
|
||||
model = AutoModel.from_config(config)
|
||||
|
||||
Due to Pytorch design, this functionality is only available for floating dtypes.
|
||||
|
||||
|
||||
|
||||
ModuleUtilsMixin
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
Reference in New Issue
Block a user