Just import torch AdamW instead (#36177)

* Just import torch AdamW instead

* Update docs too

* Make AdamW undocumented

* make fixup

* Add a basic wrapper class

* Add it back to the docs

* Just remove AdamW entirely

* Remove some AdamW references

* Drop AdamW from the public init

* make fix-copies

* Cleanup some references

* make fixup

* Delete lots of transformers.AdamW references

* Remove extra references to adamw_hf
This commit is contained in:
Matt
2025-03-19 18:29:40 +00:00
committed by GitHub
parent 51bd0ceb9e
commit 9be4728af8
18 changed files with 18 additions and 174 deletions

View File

@@ -41,7 +41,6 @@ from utils_qa import postprocess_qa_predictions_with_beam_search
import transformers
from transformers import (
-AdamW,
DataCollatorWithPadding,
EvalPrediction,
SchedulerType,
@@ -767,7 +766,7 @@ def main():
"weight_decay": 0.0,
},
]
-optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False

View File

@@ -33,7 +33,6 @@ from tqdm.auto import tqdm
import transformers
from transformers import (
-AdamW,
SchedulerType,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
@@ -583,7 +582,7 @@ def main():
)
# Optimizer
-optimizer = AdamW(
+optimizer = torch.optim.AdamW(
list(model.parameters()),
lr=args.learning_rate,
betas=[args.adam_beta1, args.adam_beta2],