FEAT / Trainer: Add adamw 4bit optimizer (#31865)

* add 4bit optimizer

* style

* fix msg

* style

* add qgalore

* Revert "add qgalore"

This reverts commit 25278e805f24d5d48eaa0638abb48de1b783a3fb.

* style

* version check
This commit is contained in:
Marc Sun
2024-08-22 15:07:09 +02:00
committed by GitHub
parent 6baa6f276a
commit c42d264549
3 changed files with 29 additions and 0 deletions

View File

@@ -99,6 +99,7 @@ from transformers.utils import (
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
is_torchao_available,
is_torchdistx_available,
)
from transformers.utils.hp_naming import TrialShortNamer
@@ -4210,6 +4211,16 @@ if is_torch_available():
dict(default_adam_kwargs, **default_anyprecision_kwargs),
)
)
if is_torchao_available():
    # Import the submodule explicitly. A bare `import torchao` only guarantees that
    # the top-level package is loaded; `torchao.prototype.low_bit_optim` would then
    # resolve only if torchao's own `__init__` happened to import it. Importing the
    # dotted path still binds the name `torchao` here, so the attribute access
    # below is unchanged.
    import torchao.prototype.low_bit_optim

    # Add the 4-bit AdamW case to the optimizer test parameters. The tuple holds
    # the TrainingArguments selecting ADAMW_TORCH_4BIT, the torchao AdamW4bit
    # class, and default_adam_kwargs (presumably the expected optimizer class and
    # expected kwargs, mirroring the other entries -- confirm against the
    # consuming test).
    optim_test_params.append(
        (
            TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None"),
            torchao.prototype.low_bit_optim.AdamW4bit,
            default_adam_kwargs,
        )
    )
@require_torch