add functions to inspect model and optimizer status to trainer.py (#29838)

* add functions to get number of params which require grad, get optimizer group for parameters and get learning rates of param groups to trainer.py

* add tests and raise ValueError when optimizer is None

* add second layer to test and freeze its weights

* check if torch is available before running tests

* use decorator to check if torch is available

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix test indentation

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
Christopher Keibel
2024-03-28 11:37:16 +01:00
committed by GitHub
parent 855b95ce34
commit aac7099c92
2 changed files with 68 additions and 0 deletions

View File

@@ -3832,3 +3832,41 @@ class HyperParameterSearchBackendsTest(unittest.TestCase):
list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
list(HPSearchBackend),
)
@require_torch
class OptimizerAndModelInspectionTest(unittest.TestCase):
    """Tests for the Trainer helpers that inspect model and optimizer status."""

    def test_get_num_trainable_parameters(self):
        """The trainable-parameter count only includes parameters that require grad."""
        layer_shapes = [(128, 64), (64, 32)]
        model = nn.Sequential(*(nn.Linear(n_in, n_out) for n_in, n_out in layer_shapes))
        # Each linear layer holds in_features * out_features weights plus one bias per output.
        expected_counts = [n_in * n_out + n_out for n_in, n_out in layer_shapes]

        trainer = Trainer(model=model)
        self.assertEqual(trainer.get_num_trainable_parameters(), sum(expected_counts))

        # Freeze the last layer; only the first layer should still be counted.
        model[-1].requires_grad_(False)
        self.assertEqual(trainer.get_num_trainable_parameters(), expected_counts[0])

    def test_get_learning_rates(self):
        """Learning rates are unavailable before, and listed after, optimizer creation."""
        trainer = Trainer(model=nn.Sequential(nn.Linear(128, 64)))

        # Querying before the optimizer has been created must raise.
        with self.assertRaises(ValueError):
            trainer.get_learning_rates()

        trainer.create_optimizer()
        learning_rates = trainer.get_learning_rates()
        # Two entries are expected, both at the rate used when no arguments are supplied.
        self.assertEqual(learning_rates, [5e-05, 5e-05])

    def test_get_optimizer_group(self):
        """Optimizer groups can be listed in full or looked up for a single parameter."""
        model = nn.Sequential(nn.Linear(128, 64))
        trainer = Trainer(model=model)

        # ValueError is raised if optimizer is None
        with self.assertRaises(ValueError):
            trainer.get_optimizer_group()

        trainer.create_optimizer()

        # Without an argument, every group is returned.
        all_groups = trainer.get_optimizer_group()
        self.assertEqual(len(all_groups), 2)

        # With a parameter, the group holding that parameter is returned.
        first_param = next(model.parameters())
        owning_group = trainer.get_optimizer_group(first_param)
        self.assertIn(first_param, owning_group["params"])