add functions to inspect model and optimizer status to trainer.py (#29838)
* add functions to get number of params which require grad, get optimizer group for parameters and get learning rates of param groups to trainer.py * add tests and raise ValueError when optimizer is None * add second layer to test and freeze its weigths * check if torch is available before running tests * use decorator to check if torch is available Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix test indentation Co-authored-by: Zach Mueller <muellerzr@gmail.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
committed by
GitHub
parent
855b95ce34
commit
aac7099c92
@@ -3832,3 +3832,41 @@ class HyperParameterSearchBackendsTest(unittest.TestCase):
|
||||
list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
|
||||
list(HPSearchBackend),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class OptimizerAndModelInspectionTest(unittest.TestCase):
|
||||
def test_get_num_trainable_parameters(self):
|
||||
model = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 32))
|
||||
# in_features * out_features + bias
|
||||
layer_1 = 128 * 64 + 64
|
||||
layer_2 = 64 * 32 + 32
|
||||
trainer = Trainer(model=model)
|
||||
self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
|
||||
# Freeze the last layer
|
||||
for param in model[-1].parameters():
|
||||
param.requires_grad = False
|
||||
self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
|
||||
|
||||
def test_get_learning_rates(self):
|
||||
model = nn.Sequential(nn.Linear(128, 64))
|
||||
trainer = Trainer(model=model)
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.get_learning_rates()
|
||||
trainer.create_optimizer()
|
||||
self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
|
||||
|
||||
def test_get_optimizer_group(self):
|
||||
model = nn.Sequential(nn.Linear(128, 64))
|
||||
trainer = Trainer(model=model)
|
||||
# ValueError is raised if optimizer is None
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.get_optimizer_group()
|
||||
trainer.create_optimizer()
|
||||
# Get groups
|
||||
num_groups = len(trainer.get_optimizer_group())
|
||||
self.assertEqual(num_groups, 2)
|
||||
# Get group of parameter
|
||||
param = next(model.parameters())
|
||||
group = trainer.get_optimizer_group(param)
|
||||
self.assertIn(param, group["params"])
|
||||
|
||||
Reference in New Issue
Block a user