Adding gradient checkpointing to GPT2 (#7446)

* GPT2 gradient checkpointing

* find_unused_parameters removed if checkpointing

* Update src/transformers/configuration_gpt2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Added a test for generation with checkpointing

* Update src/transformers/configuration_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Teven
2020-09-29 18:26:26 +02:00
committed by GitHub
parent 52e8392b7e
commit 9e9a1fb8c7
4 changed files with 79 additions and 40 deletions

@@ -679,8 +679,10 @@ class Trainer:
         model,
         device_ids=[self.args.local_rank],
         output_device=self.args.local_rank,
-        find_unused_parameters=True,
+        find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
     )
+    # find_unused_parameters breaks checkpointing as per
+    # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
     if self.tb_writer is not None:
         self.tb_writer.add_text("args", self.args.to_json_string())
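
The flag selection in the hunk above can be sketched in isolation. This is a minimal illustration, not the Trainer code itself: the helper name `ddp_find_unused_parameters` is hypothetical, and the `SimpleNamespace` objects are stand-ins for a real model and its config.

```python
from types import SimpleNamespace


def ddp_find_unused_parameters(model):
    """Return the value to pass as find_unused_parameters.

    Hypothetical helper: it mirrors the expression in the diff, which turns
    the flag off whenever the model config enables gradient checkpointing
    and defaults to True when the attribute is absent.
    """
    return not getattr(model.config, "gradient_checkpointing", False)


# Stand-in objects (assumptions for illustration, not real transformers models).
plain = SimpleNamespace(config=SimpleNamespace())
checkpointed = SimpleNamespace(config=SimpleNamespace(gradient_checkpointing=True))

print(ddp_find_unused_parameters(plain))         # True
print(ddp_find_unused_parameters(checkpointed))  # False
```

Using `getattr` with a `False` default means configs that predate the `gradient_checkpointing` attribute keep the previous behaviour (`find_unused_parameters=True`).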