Just import torch AdamW instead (#36177)
* Just import torch AdamW instead
* Update docs too
* Make AdamW undocumented
* make fixup
* Add a basic wrapper class
* Add it back to the docs
* Just remove AdamW entirely
* Remove some AdamW references
* Drop AdamW from the public init
* make fix-copies
* Cleanup some references
* make fixup
* Delete lots of transformers.AdamW references
* Remove extra references to adamw_hf
This commit is contained in:
@@ -28,7 +28,6 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
Adafactor,
|
||||
AdamW,
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_cosine_schedule_with_warmup,
|
||||
@@ -76,7 +75,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
criterion = nn.MSELoss()
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
|
||||
optimizer = torch.optim.AdamW(params=[w], lr=2e-1, weight_decay=0.0)
|
||||
for _ in range(100):
|
||||
loss = criterion(w, target)
|
||||
loss.backward()
|
||||
@@ -114,7 +113,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
@require_torch
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
m = nn.Linear(50, 50) if is_torch_available() else None
|
||||
optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
|
||||
optimizer = torch.optim.AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
|
||||
num_steps = 10
|
||||
|
||||
def assertListAlmostEqual(self, list1, list2, tol, msg=None):
|
||||
|
||||
Reference in New Issue
Block a user