@@ -59,8 +59,6 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
from .test_trainer_utils import TstLayer
|
||||
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
|
||||
@@ -992,18 +990,6 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# should be about half of fp16_init
|
||||
# perfect world: fp32_init/2 == fp16_eval
|
||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||
|
||||
def test_no_wd_param_group(self):
|
||||
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||
trainer = Trainer(model=model)
|
||||
trainer.create_optimizer_and_scheduler(10)
|
||||
# fmt: off
|
||||
wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
|
||||
# fmt: on
|
||||
wd_params = [p for n, p in model.named_parameters() if n in wd_names]
|
||||
no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
|
||||
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
|
||||
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -30,23 +30,8 @@ if is_torch_available():
|
||||
DistributedTensorGatherer,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
get_parameter_names
|
||||
)
|
||||
|
||||
class TstLayer(torch.nn.Module):
|
||||
def __init__(self, hidden_size):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.ln1 = torch.nn.LayerNorm(hidden_size)
|
||||
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
|
||||
self.ln2 = torch.nn.LayerNorm(hidden_size)
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
|
||||
|
||||
def forward(self, x):
|
||||
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
|
||||
h = torch.nn.functional.relu(self.linear2(x))
|
||||
return self.ln2(x + h + self.bias)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerUtilsTest(unittest.TestCase):
|
||||
@@ -132,12 +117,3 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
self.assertEqual(lengths[indices_process_0[0]], 50)
|
||||
# The indices should be a permutation of range(100)
|
||||
self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
|
||||
|
||||
def test_get_parameter_names(self):
|
||||
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
get_parameter_names(model, [torch.nn.LayerNorm]),
|
||||
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
Reference in New Issue
Block a user