From ca51499248b986ebf3991848234ef2d8bc81a36a Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 15 Aug 2023 11:47:47 -0400 Subject: [PATCH] Make training args fully immutable (#25435) * Make training args fully immutable * Working tests, PyTorch * In test_trainer * during testing * Use proper dataclass way * Fix test * Another one * Fix tf * Lingering slow * Exception * Clean --- examples/pytorch/image-pretraining/run_mae.py | 16 ++++++------ .../summarization/run_summarization.py | 13 ++++------ .../research_projects/mlm_wwm/run_mlm_wwm.py | 3 ++- .../tensorflow/language-modeling/run_clm.py | 5 ++-- .../tensorflow/language-modeling/run_mlm.py | 5 ++-- src/transformers/training_args.py | 12 ++++++++- tests/trainer/test_trainer.py | 5 ++-- tests/trainer/test_trainer_distributed.py | 25 +++++++++++++++---- 8 files changed, 54 insertions(+), 30 deletions(-) diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py index 1c269fba3a..0967e9b090 100644 --- a/examples/pytorch/image-pretraining/run_mae.py +++ b/examples/pytorch/image-pretraining/run_mae.py @@ -163,6 +163,15 @@ class CustomTrainingArguments(TrainingArguments): default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."} ) + def __post_init__(self): + # Compute absolute learning rate while args are mutable + super().__post_init__() + if self.base_learning_rate is not None: + total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size + delattr(self, "_frozen") + self.learning_rate = self.base_learning_rate * total_train_batch_size / 256 + setattr(self, "_frozen", True) + def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) @@ -353,13 +362,6 @@ def main(): # Set the validation transforms ds["validation"].set_transform(preprocess_images) - # Compute absolute learning rate - total_train_batch_size = ( - training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size - ) - if training_args.base_learning_rate is not None: - training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256 - # Initialize our trainer trainer = Trainer( model=model, diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index d145a8549d..e19ecf815a 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -18,6 +18,7 @@ Fine-tuning the library models for sequence to sequence. """ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. +import dataclasses import logging import os import sys @@ -674,14 +675,10 @@ def main(): return result # Override the decoding parameters of Seq2SeqTrainer - training_args.generation_max_length = ( - training_args.generation_max_length - if training_args.generation_max_length is not None - else data_args.val_max_target_length - ) - training_args.generation_num_beams = ( - data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams - ) + if training_args.generation_max_length is None: + training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length) + if training_args.generation_num_beams is None: + training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams) # Initialize our Trainer trainer = Seq2SeqTrainer( diff --git a/examples/research_projects/mlm_wwm/run_mlm_wwm.py b/examples/research_projects/mlm_wwm/run_mlm_wwm.py index f14ad5adfe..4bb138de83 100644 --- a/examples/research_projects/mlm_wwm/run_mlm_wwm.py +++ b/examples/research_projects/mlm_wwm/run_mlm_wwm.py @@ -21,6 +21,7 @@ https://huggingface.co/models?filter=fill-mask """ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. +import dataclasses import json import logging import math @@ -366,7 +367,7 @@ def main(): # If we have ref files, need to avoid it removed by trainer has_ref = data_args.train_ref_file or data_args.validation_ref_file if has_ref: - training_args.remove_unused_columns = False + training_args = dataclasses.replace(training_args, remove_unused_columns=False) # Data collator # This one will take care of randomly masking the tokens. diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py index 1614bbd4b1..033baf5917 100755 --- a/examples/tensorflow/language-modeling/run_clm.py +++ b/examples/tensorflow/language-modeling/run_clm.py @@ -259,7 +259,6 @@ def main(): assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." if training_args.output_dir is not None: - training_args.output_dir = Path(training_args.output_dir) os.makedirs(training_args.output_dir, exist_ok=True) # endregion @@ -267,8 +266,8 @@ def main(): # Detecting last checkpoint. checkpoint = None if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir: - config_path = training_args.output_dir / CONFIG_NAME - weights_path = training_args.output_dir / TF2_WEIGHTS_NAME + config_path = Path(training_args.output_dir) / CONFIG_NAME + weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME if config_path.is_file() and weights_path.is_file(): checkpoint = training_args.output_dir logger.info( diff --git a/examples/tensorflow/language-modeling/run_mlm.py b/examples/tensorflow/language-modeling/run_mlm.py index 671331745d..7423817f58 100755 --- a/examples/tensorflow/language-modeling/run_mlm.py +++ b/examples/tensorflow/language-modeling/run_mlm.py @@ -265,7 +265,6 @@ def main(): assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." if training_args.output_dir is not None: - training_args.output_dir = Path(training_args.output_dir) os.makedirs(training_args.output_dir, exist_ok=True) if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length: @@ -277,8 +276,8 @@ def main(): # Detecting last checkpoint. checkpoint = None if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir: - config_path = training_args.output_dir / CONFIG_NAME - weights_path = training_args.output_dir / TF2_WEIGHTS_NAME + config_path = Path(training_args.output_dir) / CONFIG_NAME + weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME if config_path.is_file() and weights_path.is_file(): checkpoint = training_args.output_dir logger.warning( diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 95732bbfc5..f27c7cd0ce 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -18,7 +18,7 @@ import json import math import os import warnings -from dataclasses import asdict, dataclass, field, fields +from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields from datetime import timedelta from enum import Enum from pathlib import Path @@ -1687,6 +1687,16 @@ class TrainingArguments: mixed_precision_dtype = "bf16" os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + # Finally set the `TrainingArguments` to be immutable + self._frozen = True + + def __setattr__(self, name, value): + # Once fully through the `__post_init__`, `TrainingArguments` are immutable + if not name.startswith("_") and getattr(self, "_frozen", False): + raise FrozenInstanceError(f"cannot assign to field {name}") + else: + super().__setattr__(name, value) + def __str__(self): self_as_dict = asdict(self) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index db7f9bb20c..152fab898c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments): b: float = 0.0 def __post_init__(self): - super().__post_init__() # save resources not dealing with reporting (also avoids the warning when it's not set) self.report_to = [] + super().__post_init__() class RepeatDataset: @@ -529,7 +529,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): self.check_trained_model(trainer.model) # Re-training should restart from scratch, thus lead the same results and new seed should be used. - trainer.args.seed = 314 + args = TrainingArguments("./regression", learning_rate=0.1, seed=314) + trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel()) trainer.train() self.check_trained_model(trainer.model, alternate_seed=True) diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py index 5a7734b8ba..f8b59d967c 100644 --- a/tests/trainer/test_trainer_distributed.py +++ b/tests/trainer/test_trainer_distributed.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses from typing import Dict import numpy as np @@ -205,7 +206,14 @@ if __name__ == "__main__": logger.error(p.metrics) exit(1) - trainer.args.eval_accumulation_steps = 2 + training_args = dataclasses.replace(training_args, eval_accumulation_steps=2) + trainer = Trainer( + model=DummyModel(), + args=training_args, + data_collator=DummyDataCollator(), + eval_dataset=dataset, + compute_metrics=compute_metrics, + ) metrics = trainer.evaluate() logger.info(metrics) @@ -219,15 +227,22 @@ if __name__ == "__main__": logger.error(p.metrics) exit(1) - trainer.args.eval_accumulation_steps = None + training_args = dataclasses.replace(training_args, eval_accumulation_steps=None) + trainer = Trainer( + model=DummyModel(), + args=training_args, + data_collator=DummyDataCollator(), + eval_dataset=dataset, + compute_metrics=compute_metrics, + ) # Check that `dispatch_batches=False` will work on a finite iterable dataset train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1) model = RegressionModel() - training_args.per_device_train_batch_size = 1 - training_args.max_steps = 1 - training_args.dispatch_batches = False + training_args = dataclasses.replace( + training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False + ) trainer = Trainer(model, training_args, train_dataset=train_dataset) trainer.train()