Support adamw_torch_8bit (#34993)

* var

* more

* test
This commit is contained in:
fzyzcjy
2025-01-21 21:17:49 +08:00
committed by GitHub
parent f82b19cb6f
commit dc10f7906a
3 changed files with 19 additions and 3 deletions

View File

@@ -5017,6 +5017,13 @@ if is_torch_available():
default_adam_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_8BIT, output_dir="None"),
torchao.prototype.low_bit_optim.AdamW8bit,
default_adam_kwargs,
)
)
@require_torch