Revert frozen training arguments (#25903)
* Revert frozen training arguments * TODO
This commit is contained in:
@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
|
||||
default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Compute absolute learning rate while args are mutable
|
||||
super().__post_init__()
|
||||
if self.base_learning_rate is not None:
|
||||
total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
|
||||
delattr(self, "_frozen")
|
||||
self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
|
||||
setattr(self, "_frozen", True)
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
@@ -362,6 +353,13 @@ def main():
|
||||
# Set the validation transforms
|
||||
ds["validation"].set_transform(preprocess_images)
|
||||
|
||||
# Compute absolute learning rate
|
||||
total_train_batch_size = (
|
||||
training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
if training_args.base_learning_rate is not None:
|
||||
training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
|
||||
|
||||
# Initialize our trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
|
||||
Reference in New Issue
Block a user