Make gradient_checkpointing a training argument (#13657)
* Make gradient_checkpointing a training argument * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update src/transformers/configuration_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Fix tests * Style * document Gradient Checkpointing as a performance feature * Small rename * PoC for not using the config * Adapt BC to new PoC * Forgot to save * Rollout changes to all other models * Fix typo Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -46,8 +46,8 @@ Tips:
|
||||
- LED makes use of *global attention* by means of the ``global_attention_mask`` (see
|
||||
:class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
|
||||
``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question.
|
||||
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting
|
||||
``config.gradient_checkpointing = True``.
|
||||
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
|
||||
``model.gradient_checkpointing_enable()``.
|
||||
- A notebook showing how to evaluate LED, can be accessed `here
|
||||
<https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
|
||||
- A notebook showing how to fine-tune LED, can be accessed `here
|
||||
|
||||
@@ -53,6 +53,7 @@ Software:
|
||||
- Tensor Parallelism
|
||||
- Low-memory Optimizers
|
||||
- fp16/bf16 (smaller data)
|
||||
- Gradient checkpointing
|
||||
|
||||
|
||||
|
||||
@@ -226,6 +227,21 @@ pytorch `autocast` which performs AMP include a caching feature, which speed thi
|
||||
|
||||
Autocast maintains a cache of the FP16 casts of model params (leaves). This helps streamline parameter reuse: if the same FP32 param is used in several different FP16list ops, like several matmuls, instead of re-casting the param to FP16 on entering each matmul, the cast will occur on the first matmul, the casted FP16 copy will be cached, and for all later matmuls the FP16 copy will be reused. The cache is maintained only within a particular outermost autocast context. When you exit the autocast context the cache is dropped. For recommended usage, in which autocast wraps the forward pass, and then you exit the context before calling backward(), this means the cache only lasts the duration of the forward pass each iteration, and will be rebuilt next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration. The FP32 params get updated by the optimizer, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.)
|
||||
|
||||
|
||||
### Gradient Checkpointing
|
||||
|
||||
One way to use significantly less GPU memory is to enabled "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of small decrease in the training speed due to recomputing parts of the graph during back-propagation.
|
||||
|
||||
This technique was first shared in the paper: [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper will also give you the exact details on the savings, but it's in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers.
|
||||
|
||||
To activate this feature in 🤗 Transformers for models that support it, use:
|
||||
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
or add `--gradient_checkpointing` to the Trainer arguments.
|
||||
|
||||
|
||||
### Batch sizes
|
||||
|
||||
One gets the most efficient performance when batch sizes and input/output neuron counts are divisible by a certain number, which typically starts at 8, but can be much higher as well. That number varies a lot depending on the specific hardware being used and the dtype of the model.
|
||||
|
||||
Reference in New Issue
Block a user