Docs: fix GaLore optimizer code example (#32249)
Docs: fix GaLore optimizer example Fix incorrect usage of GaLore optimizer in Transformers trainer code example. The GaLore optimizer uses low-rank gradient updates to reduce memory usage. GaLore is quite popular and is implemented by the authors in [https://github.com/jiaweizzhao/GaLore](https://github.com/jiaweizzhao/GaLore). A few months ago GaLore was added to the HuggingFace Transformers library in https://github.com/huggingface/transformers/pull/29588. Documentation of the Trainer module includes a few code examples of how to use GaLore. However, the `optim_targe_modules` argument to the `TrainingArguments` function is incorrect, as discussed in https://github.com/huggingface/transformers/pull/29588#issuecomment-2006289512. This pull request fixes this issue.
This commit is contained in:
@@ -278,7 +278,7 @@ args = TrainingArguments(
|
|||||||
max_steps=100,
|
max_steps=100,
|
||||||
per_device_train_batch_size=2,
|
per_device_train_batch_size=2,
|
||||||
optim="galore_adamw",
|
optim="galore_adamw",
|
||||||
optim_target_modules=["attn", "mlp"]
|
optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
@@ -315,7 +315,7 @@ args = TrainingArguments(
|
|||||||
max_steps=100,
|
max_steps=100,
|
||||||
per_device_train_batch_size=2,
|
per_device_train_batch_size=2,
|
||||||
optim="galore_adamw",
|
optim="galore_adamw",
|
||||||
optim_target_modules=["attn", "mlp"],
|
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
|
||||||
optim_args="rank=64, update_proj_gap=100, scale=0.10",
|
optim_args="rank=64, update_proj_gap=100, scale=0.10",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -359,7 +359,7 @@ args = TrainingArguments(
|
|||||||
max_steps=100,
|
max_steps=100,
|
||||||
per_device_train_batch_size=2,
|
per_device_train_batch_size=2,
|
||||||
optim="galore_adamw_layerwise",
|
optim="galore_adamw_layerwise",
|
||||||
optim_target_modules=["attn", "mlp"]
|
optim_target_modules=[r".*.attn.*", r".*.mlp.*"]
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
|
|||||||
Reference in New Issue
Block a user