Just import torch AdamW instead (#36177)

* Just import torch AdamW instead

* Update docs too

* Make AdamW undocumented

* make fixup

* Add a basic wrapper class

* Add it back to the docs

* Just remove AdamW entirely

* Remove some AdamW references

* Drop AdamW from the public init

* make fix-copies

* Cleanup some references

* make fixup

* Delete lots of transformers.AdamW references

* Remove extra references to adamw_hf
This commit is contained in:
Matt
2025-03-19 18:29:40 +00:00
committed by GitHub
parent 51bd0ceb9e
commit 9be4728af8
18 changed files with 18 additions and 174 deletions

View File

@@ -28,7 +28,6 @@ if is_torch_available():
from transformers import (
Adafactor,
AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
@@ -76,7 +75,7 @@ class OptimizationTest(unittest.TestCase):
target = torch.tensor([0.4, 0.2, -0.5])
criterion = nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
optimizer = torch.optim.AdamW(params=[w], lr=2e-1, weight_decay=0.0)
for _ in range(100):
loss = criterion(w, target)
loss.backward()
@@ -114,7 +113,7 @@ class OptimizationTest(unittest.TestCase):
@require_torch
class ScheduleInitTest(unittest.TestCase):
m = nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
optimizer = torch.optim.AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol, msg=None):