[style] consistent nn. and nn.functional: part 3 tests (#12155)
* consistent nn. and nn.functional: p3 templates
* restore
This commit is contained in:
@@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
+ from torch import nn
|
||||
|
||||
from transformers import (
|
||||
Adafactor,
|
||||
@@ -70,7 +71,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
def test_adam_w(self):
|
||||
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
- criterion = torch.nn.MSELoss()
|
||||
+ criterion = nn.MSELoss()
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
|
||||
for _ in range(100):
|
||||
@@ -84,7 +85,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
def test_adafactor(self):
|
||||
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
- criterion = torch.nn.MSELoss()
|
||||
+ criterion = nn.MSELoss()
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = Adafactor(
|
||||
params=[w],
|
||||
@@ -109,7 +110,7 @@ class OptimizationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
- m = torch.nn.Linear(50, 50) if is_torch_available() else None
|
||||
+ m = nn.Linear(50, 50) if is_torch_available() else None
|
||||
optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
|
||||
num_steps = 10
|
||||
|
||||
|
||||
Reference in New Issue
Block a user