Tests
This commit is contained in:
@@ -80,6 +80,7 @@ from .trainer_pt_utils import (
|
|||||||
SequentialDistributedSampler,
|
SequentialDistributedSampler,
|
||||||
distributed_broadcast_scalars,
|
distributed_broadcast_scalars,
|
||||||
distributed_concat,
|
distributed_concat,
|
||||||
|
get_parameter_names,
|
||||||
nested_concat,
|
nested_concat,
|
||||||
nested_detach,
|
nested_detach,
|
||||||
nested_numpify,
|
nested_numpify,
|
||||||
@@ -613,14 +614,15 @@ class Trainer:
|
|||||||
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
|
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
|
||||||
"""
|
"""
|
||||||
if self.optimizer is None:
|
if self.optimizer is None:
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
|
||||||
|
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||||
optimizer_grouped_parameters = [
|
optimizer_grouped_parameters = [
|
||||||
{
|
{
|
||||||
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
|
||||||
"weight_decay": self.args.weight_decay,
|
"weight_decay": self.args.weight_decay,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
"params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -672,3 +672,19 @@ def save_state(self):
|
|||||||
|
|
||||||
path = os.path.join(self.args.output_dir, "trainer_state.json")
|
path = os.path.join(self.args.output_dir, "trainer_state.json")
|
||||||
self.state.save_to_json(path)
|
self.state.save_to_json(path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_names(model, forbidden_layer_types):
|
||||||
|
"""
|
||||||
|
Returns the names of the model parameters that are not inside a forbidden layer.
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for name, child in model.named_children():
|
||||||
|
result += [
|
||||||
|
f"{name}.{n}"
|
||||||
|
for n in get_parameter_names(child, forbidden_layer_types)
|
||||||
|
if not isinstance(child, tuple(forbidden_layer_types))
|
||||||
|
]
|
||||||
|
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
|
||||||
|
result += list(model._parameters.keys())
|
||||||
|
return result
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
from transformers.modeling_utils import unwrap_model
|
from transformers.modeling_utils import unwrap_model
|
||||||
|
|
||||||
|
from .test_trainer_utils import TstLayer
|
||||||
|
|
||||||
|
|
||||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||||
|
|
||||||
@@ -991,6 +993,18 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
# perfect world: fp32_init/2 == fp16_eval
|
# perfect world: fp32_init/2 == fp16_eval
|
||||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||||
|
|
||||||
|
def test_no_wd_param_group(self):
|
||||||
|
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||||
|
trainer = Trainer(model=model)
|
||||||
|
trainer.create_optimizer_and_scheduler(10)
|
||||||
|
# fmt: off
|
||||||
|
wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
|
||||||
|
# fmt: on
|
||||||
|
wd_params = [p for n, p in model.named_parameters() if n in wd_names]
|
||||||
|
no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
|
||||||
|
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
|
||||||
|
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_optuna
|
@require_optuna
|
||||||
|
|||||||
@@ -30,8 +30,23 @@ if is_torch_available():
|
|||||||
DistributedTensorGatherer,
|
DistributedTensorGatherer,
|
||||||
LabelSmoother,
|
LabelSmoother,
|
||||||
LengthGroupedSampler,
|
LengthGroupedSampler,
|
||||||
|
get_parameter_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class TstLayer(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.ln1 = torch.nn.LayerNorm(hidden_size)
|
||||||
|
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.ln2 = torch.nn.LayerNorm(hidden_size)
|
||||||
|
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
|
||||||
|
h = torch.nn.functional.relu(self.linear2(x))
|
||||||
|
return self.ln2(x + h + self.bias)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TrainerUtilsTest(unittest.TestCase):
|
class TrainerUtilsTest(unittest.TestCase):
|
||||||
@@ -117,3 +132,12 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
self.assertEqual(lengths[indices_process_0[0]], 50)
|
self.assertEqual(lengths[indices_process_0[0]], 50)
|
||||||
# The indices should be a permutation of range(100)
|
# The indices should be a permutation of range(100)
|
||||||
self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
|
self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
|
||||||
|
|
||||||
|
def test_get_parameter_names(self):
|
||||||
|
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||||
|
# fmt: off
|
||||||
|
self.assertEqual(
|
||||||
|
get_parameter_names(model, [torch.nn.LayerNorm]),
|
||||||
|
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|||||||
Reference in New Issue
Block a user