Tests
This commit is contained in:
@@ -59,6 +59,8 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
from .test_trainer_utils import TstLayer
|
||||
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
|
||||
@@ -990,6 +992,18 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# should be about half of fp16_init
|
||||
# perfect world: fp32_init/2 == fp16_eval
|
||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||
|
||||
def test_no_wd_param_group(self):
|
||||
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||
trainer = Trainer(model=model)
|
||||
trainer.create_optimizer_and_scheduler(10)
|
||||
# fmt: off
|
||||
wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
|
||||
# fmt: on
|
||||
wd_params = [p for n, p in model.named_parameters() if n in wd_names]
|
||||
no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
|
||||
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
|
||||
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user