Make training args fully immutable (#25435)
* Make training args fully immutable * Working tests, PyTorch * In test_trainer * during testing * Use proper dataclass way * Fix test * Another one * Fix tf * Lingering slow * Exception * Clean
This commit is contained in:
@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
|
||||
b: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# save resources not dealing with reporting (also avoids the warning when it's not set)
|
||||
self.report_to = []
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
class RepeatDataset:
|
||||
@@ -529,7 +529,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.check_trained_model(trainer.model)
|
||||
|
||||
# Re-training should restart from scratch, thus lead the same results and new seed should be used.
|
||||
trainer.args.seed = 314
|
||||
args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
|
||||
trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
|
||||
trainer.train()
|
||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user