Make training args fully immutable (#25435)

* Make training args fully immutable

* Working tests, PyTorch

* In test_trainer

* during testing

* Use proper dataclass way

* Fix test

* Another one

* Fix tf

* Lingering slow

* Exception

* Clean
This commit is contained in:
Zach Mueller
2023-08-15 11:47:47 -04:00
committed by GitHub
parent f11518a542
commit ca51499248
8 changed files with 54 additions and 30 deletions

View File

@@ -163,6 +163,15 @@ class CustomTrainingArguments(TrainingArguments):
default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."} default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
) )
def __post_init__(self):
    """Run the parent initialization, then derive the absolute learning rate.

    When ``base_learning_rate`` is set, ``learning_rate`` is overwritten with
    ``base_learning_rate * total_train_batch_size / 256``, where the total
    batch size is ``train_batch_size * gradient_accumulation_steps * world_size``.
    When it is ``None``, ``learning_rate`` is left as the parent produced it.
    """
    super().__post_init__()
    if self.base_learning_rate is None:
        return

    effective_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size

    # The `_frozen` flag is expected to be set once the parent `__post_init__`
    # has run; drop it for the single assignment below and put it back right after.
    del self._frozen
    self.learning_rate = self.base_learning_rate * effective_batch_size / 256
    self._frozen = True
def collate_fn(examples): def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -353,13 +362,6 @@ def main():
# Set the validation transforms # Set the validation transforms
ds["validation"].set_transform(preprocess_images) ds["validation"].set_transform(preprocess_images)
# Compute absolute learning rate
total_train_batch_size = (
training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
if training_args.base_learning_rate is not None:
training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
# Initialize our trainer # Initialize our trainer
trainer = Trainer( trainer = Trainer(
model=model, model=model,

View File

@@ -18,6 +18,7 @@ Fine-tuning the library models for sequence to sequence.
""" """
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import dataclasses
import logging import logging
import os import os
import sys import sys
@@ -674,14 +675,10 @@ def main():
return result return result
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = ( if training_args.generation_max_length is None:
training_args.generation_max_length training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
if training_args.generation_max_length is not None if training_args.generation_num_beams is None:
else data_args.val_max_target_length training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
)
training_args.generation_num_beams = (
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
)
# Initialize our Trainer # Initialize our Trainer
trainer = Seq2SeqTrainer( trainer = Seq2SeqTrainer(

View File

@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=fill-mask
""" """
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import dataclasses
import json import json
import logging import logging
import math import math
@@ -366,7 +367,7 @@ def main():
# If we have ref files, need to avoid it removed by trainer # If we have ref files, need to avoid it removed by trainer
has_ref = data_args.train_ref_file or data_args.validation_ref_file has_ref = data_args.train_ref_file or data_args.validation_ref_file
if has_ref: if has_ref:
training_args.remove_unused_columns = False training_args = dataclasses.replace(training_args, remove_unused_columns=False)
# Data collator # Data collator
# This one will take care of randomly masking the tokens. # This one will take care of randomly masking the tokens.

View File

@@ -259,7 +259,6 @@ def main():
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if training_args.output_dir is not None: if training_args.output_dir is not None:
training_args.output_dir = Path(training_args.output_dir)
os.makedirs(training_args.output_dir, exist_ok=True) os.makedirs(training_args.output_dir, exist_ok=True)
# endregion # endregion
@@ -267,8 +266,8 @@ def main():
# Detecting last checkpoint. # Detecting last checkpoint.
checkpoint = None checkpoint = None
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir: if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
config_path = training_args.output_dir / CONFIG_NAME config_path = Path(training_args.output_dir) / CONFIG_NAME
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
if config_path.is_file() and weights_path.is_file(): if config_path.is_file() and weights_path.is_file():
checkpoint = training_args.output_dir checkpoint = training_args.output_dir
logger.info( logger.info(

View File

@@ -265,7 +265,6 @@ def main():
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if training_args.output_dir is not None: if training_args.output_dir is not None:
training_args.output_dir = Path(training_args.output_dir)
os.makedirs(training_args.output_dir, exist_ok=True) os.makedirs(training_args.output_dir, exist_ok=True)
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length: if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -277,8 +276,8 @@ def main():
# Detecting last checkpoint. # Detecting last checkpoint.
checkpoint = None checkpoint = None
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir: if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
config_path = training_args.output_dir / CONFIG_NAME config_path = Path(training_args.output_dir) / CONFIG_NAME
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
if config_path.is_file() and weights_path.is_file(): if config_path.is_file() and weights_path.is_file():
checkpoint = training_args.output_dir checkpoint = training_args.output_dir
logger.warning( logger.warning(

View File

@@ -18,7 +18,7 @@ import json
import math import math
import os import os
import warnings import warnings
from dataclasses import asdict, dataclass, field, fields from dataclasses import FrozenInstanceError, asdict, dataclass, field, fields
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@@ -1687,6 +1687,16 @@ class TrainingArguments:
mixed_precision_dtype = "bf16" mixed_precision_dtype = "bf16"
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
# Finally set the `TrainingArguments` to be immutable
self._frozen = True
def __setattr__(self, name, value):
    """Assign ``value`` to ``name`` unless the instance has been frozen.

    Names starting with an underscore are always writable. Any other name is
    writable only while ``_frozen`` is missing or falsy.

    Raises:
        FrozenInstanceError: if ``name`` does not start with ``_`` and the
            instance's ``_frozen`` attribute is truthy.
    """
    # Private names bypass the freeze check entirely, so `_frozen` itself can
    # still be set or reset after the instance is frozen.
    if name.startswith("_") or not getattr(self, "_frozen", False):
        super().__setattr__(name, value)
        return
    raise FrozenInstanceError(f"cannot assign to field {name}")
def __str__(self): def __str__(self):
self_as_dict = asdict(self) self_as_dict = asdict(self)

View File

@@ -139,9 +139,9 @@ class RegressionTrainingArguments(TrainingArguments):
b: float = 0.0 b: float = 0.0
def __post_init__(self): def __post_init__(self):
super().__post_init__()
# save resources not dealing with reporting (also avoids the warning when it's not set) # save resources not dealing with reporting (also avoids the warning when it's not set)
self.report_to = [] self.report_to = []
super().__post_init__()
class RepeatDataset: class RepeatDataset:
@@ -529,7 +529,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.check_trained_model(trainer.model) self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead the same results and new seed should be used. # Re-training should restart from scratch, thus lead the same results and new seed should be used.
trainer.args.seed = 314 args = TrainingArguments("./regression", learning_rate=0.1, seed=314)
trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
trainer.train() trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True) self.check_trained_model(trainer.model, alternate_seed=True)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import dataclasses
from typing import Dict from typing import Dict
import numpy as np import numpy as np
@@ -205,7 +206,14 @@ if __name__ == "__main__":
logger.error(p.metrics) logger.error(p.metrics)
exit(1) exit(1)
trainer.args.eval_accumulation_steps = 2 training_args = dataclasses.replace(training_args, eval_accumulation_steps=2)
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
metrics = trainer.evaluate() metrics = trainer.evaluate()
logger.info(metrics) logger.info(metrics)
@@ -219,15 +227,22 @@ if __name__ == "__main__":
logger.error(p.metrics) logger.error(p.metrics)
exit(1) exit(1)
trainer.args.eval_accumulation_steps = None training_args = dataclasses.replace(training_args, eval_accumulation_steps=None)
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
# Check that `dispatch_batches=False` will work on a finite iterable dataset # Check that `dispatch_batches=False` will work on a finite iterable dataset
train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1) train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
model = RegressionModel() model = RegressionModel()
training_args.per_device_train_batch_size = 1 training_args = dataclasses.replace(
training_args.max_steps = 1 training_args, per_device_train_batch_size=1, max_steps=1, dispatch_batches=False
training_args.dispatch_batches = False )
trainer = Trainer(model, training_args, train_dataset=train_dataset) trainer = Trainer(model, training_args, train_dataset=train_dataset)
trainer.train() trainer.train()