[trainer] deepspeed integration (#9211)

* deepspeed integration

* style

* add test

* ds wants to do its own backward

* fp16 assert

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

* for clarity extract what args are being passed to deepspeed

* introduce the concept of self.wrapped_model

* s/self.wrapped_model/self.model_wrapped/

* complete transition to self.wrapped_model / self.model

* fix

* doc

* give ds its own init

* add custom overrides, handle bs correctly

* fix test

* clean up model_init logic, fix small bug

* complete fix

* collapse --deepspeed_config into --deepspeed

* style

* start adding doc notes

* style

* implement hf2ds optimizer and scheduler configuration remapping

* oops

* call get_num_training_steps absolutely when needed

* workaround broken auto-formatter

* deepspeed_config arg is no longer needed - fixed in deepspeed master

* use hf's fp16 args in config

* clean

* start on the docs

* rebase cleanup

* finish up --fp16

* clarify the supported stages

* big refactor thanks to discovering deepspeed.init_distributed

* cleanup

* revert fp16 part

* add checkpoint-support

* more init ds into integrations

* extend docs

* cleanup

* unfix docs

* clean up old code

* imports

* move docs

* fix logic

* make it clear which file it's referring to

* document nodes/gpus

* style

* wrong format

* style

* deepspeed handles gradient clipping

* easier to read

* major doc rewrite

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* docs

* switch to AdamW optimizer

* style

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* clarify doc

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Stas Bekman
2021-01-12 19:05:18 -08:00
committed by GitHub
parent 5f6721032a
commit 2df34f4aba
7 changed files with 741 additions and 57 deletions

View File

@@ -278,45 +278,6 @@ pass it to the trainer.
Finally, you can view the results, including any calculated metrics, by launching tensorboard in your specified
``logging_dir`` directory.
Trainer Integrations
-----------------------------------------------------------------------------------------------------------------------
The trainer is being extended to support experimental libraries that may dramatically improve your training time and
fit bigger models.
The main part that is being integrated at the moment is based on the paper `ZeRO: Memory Optimizations Toward Training
Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He
<https://arxiv.org/abs/1910.02054>`__.
You can already deploy the following features from this paper:
* Optimizer State Sharding
* Gradient Sharding
using the `--sharded_ddp` trainer argument. This is implemented via `fairscale
<https://github.com/facebookresearch/fairscale/>`__, so you will have to install this library.
This feature requires distributed training (so multiple GPUs) and is not implemented for TPUs.
For example here is how you could use it for `finetune_trainer.py`:
.. code-block:: bash
cd examples/seq2seq
python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py \
--model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --data_dir wmt_en_ro \
--output_dir output_dir --overwrite_output_dir \
--do_train --n_train 500 --num_train_epochs 1 \
--per_device_train_batch_size 1 --freeze_embeds \
--src_lang en_XX --tgt_lang ro_RO --task translation \
--fp16 --sharded_ddp
Note that it works with `--fp16` too, to make things even faster.
One of the main benefits of enabling `--sharded_ddp` is that it uses a lot less GPU memory, so you should be able to
use significantly larger batch sizes using the same hardware (e.g. 3x or bigger).
Eventually more parts will be supported via integrating `DeepSpeed <https://github.com/microsoft/DeepSpeed>`__.
.. _additional-resources: