Add torch_empty_cache_steps to TrainingArguments (#31546)
* Add torch_empty_cache_steps to TrainingArguments * Fix formatting * Add torch_empty_cache_steps to docs on single gpu training * Remove check for torch_empty_cache_steps <= max_steps * Captalize Tip * Be device agnostic * Fix linting
This commit is contained in:
@@ -41,21 +41,22 @@ hyperparameter tuning, you should determine which batch size yields the best res
|
||||
|
||||
The methods and tools covered in this guide can be classified based on the effect they have on the training process:
|
||||
|
||||
| Method/tool | Improves training speed | Optimizes memory utilization |
|
||||
|:-----------------------------------------------------------|:------------------------|:-----------------------------|
|
||||
| [Batch size choice](#batch-size-choice) | Yes | Yes |
|
||||
| [Gradient accumulation](#gradient-accumulation) | No | Yes |
|
||||
| [Gradient checkpointing](#gradient-checkpointing) | No | Yes |
|
||||
| [Mixed precision training](#mixed-precision-training) | Yes | (No) |
|
||||
| [Optimizer choice](#optimizer-choice) | Yes | Yes |
|
||||
| [Data preloading](#data-preloading) | Yes | No |
|
||||
| [DeepSpeed Zero](#deepspeed-zero) | No | Yes |
|
||||
| [torch.compile](#using-torchcompile) | Yes | No |
|
||||
| [Parameter-Efficient Fine Tuning (PEFT)](#using--peft) | No | Yes |
|
||||
| Method/tool | Improves training speed | Optimizes memory utilization |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------|:-----------------------------|
|
||||
| [Batch size choice](#batch-size-choice) | Yes | Yes |
|
||||
| [Gradient accumulation](#gradient-accumulation) | No | Yes |
|
||||
| [Gradient checkpointing](#gradient-checkpointing) | No | Yes |
|
||||
| [Mixed precision training](#mixed-precision-training) | Yes | Maybe* |
|
||||
| [torch_empty_cache_steps](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.torch_empty_cache_steps) | No | Yes |
|
||||
| [Optimizer choice](#optimizer-choice) | Yes | Yes |
|
||||
| [Data preloading](#data-preloading) | Yes | No |
|
||||
| [DeepSpeed Zero](#deepspeed-zero) | No | Yes |
|
||||
| [torch.compile](#using-torchcompile) | Yes | No |
|
||||
| [Parameter-Efficient Fine Tuning (PEFT)](#using--peft) | No | Yes |
|
||||
|
||||
<Tip>
|
||||
|
||||
Note: when using mixed precision with a small model and a large batch size, there will be some memory savings but with a
|
||||
*Note: when using mixed precision with a small model and a large batch size, there will be some memory savings but with a
|
||||
large model and a small batch size, the memory use will be larger.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -221,6 +221,11 @@ if is_accelerate_available():
|
||||
DistributedDataParallelKwargs,
|
||||
DistributedType,
|
||||
GradientAccumulationPlugin,
|
||||
is_mlu_available,
|
||||
is_mps_available,
|
||||
is_npu_available,
|
||||
is_torch_version,
|
||||
is_xpu_available,
|
||||
load_fsdp_model,
|
||||
load_fsdp_optimizer,
|
||||
save_fsdp_model,
|
||||
@@ -3307,6 +3312,20 @@ class Trainer:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
|
||||
del inputs
|
||||
if (
|
||||
self.args.torch_empty_cache_steps is not None
|
||||
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
||||
):
|
||||
if is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_mlu_available():
|
||||
torch.mlu.empty_cache()
|
||||
elif is_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
elif is_torch_version(">=", "2.0") and is_mps_available():
|
||||
torch.mps.empty_cache()
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
kwargs = {}
|
||||
|
||||
|
||||
@@ -267,6 +267,15 @@ class TrainingArguments:
|
||||
eval_delay (`float`, *optional*):
|
||||
Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
|
||||
eval_strategy.
|
||||
torch_empty_cache_steps (`int`, *optional*):
|
||||
Number of steps to wait before calling `torch.<device>.empty_cache()`. If left unset or set to None, cache will not be emptied.
|
||||
|
||||
<Tip>
|
||||
|
||||
This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372).
|
||||
|
||||
</Tip>
|
||||
|
||||
learning_rate (`float`, *optional*, defaults to 5e-5):
|
||||
The initial learning rate for [`AdamW`] optimizer.
|
||||
weight_decay (`float`, *optional*, defaults to 0):
|
||||
@@ -851,6 +860,15 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
|
||||
torch_empty_cache_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of steps to wait before calling `torch.<device>.empty_cache()`."
|
||||
"This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372)."
|
||||
"If left unset or set to None, cache will not be emptied."
|
||||
},
|
||||
)
|
||||
|
||||
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||
@@ -1532,6 +1550,12 @@ class TrainingArguments:
|
||||
if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO:
|
||||
self.do_eval = True
|
||||
|
||||
if self.torch_empty_cache_steps is not None:
|
||||
if not (isinstance(self.torch_empty_cache_steps, int) or self.torch_empty_cache_steps > 0):
|
||||
raise ValueError(
|
||||
f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}."
|
||||
)
|
||||
|
||||
# eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
|
||||
if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
|
||||
if self.logging_steps > 0:
|
||||
|
||||
Reference in New Issue
Block a user