Make training args fully immutable (#25435)

* Make training args fully immutable

* Working tests, PyTorch

* In test_trainer

* during testing

* Use proper dataclass way

* Fix test

* Another one

* Fix tf

* Lingering slow

* Exception

* Clean
This commit is contained in:
Zach Mueller
2023-08-15 11:47:47 -04:00
committed by GitHub
parent f11518a542
commit ca51499248
8 changed files with 54 additions and 30 deletions

View File

@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
b: float = 0.0
def __post_init__(self):
super().__post_init__()
# save resources not dealing with reporting (also avoids the warning when it's not set)
self.report_to = []
super().__post_init__()
class RepeatDataset:
@@ -529,7 +529,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead to the same results, and the new seed should be used.
trainer.args.seed = 314
args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from typing import Dict
import numpy as np
@@ -205,7 +206,14 @@ if __name__ == "__main__":
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = 2
training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
metrics = trainer.evaluate()
logger.info(metrics)
@@ -219,15 +227,22 @@ if __name__ == "__main__":
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = None
training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
# Check that `dispatch_batches=False` will work on a finite iterable dataset
train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
model = RegressionModel()
training_args.per_device_train_batch_size = 1
training_args.max_steps = 1
training_args.dispatch_batches = False
training_args = dataclasses.replace(
training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
)
trainer = Trainer(model, training_args, train_dataset=train_dataset)
trainer.train()