Revert frozen training arguments (#25903)

* Revert frozen training arguments

* TODO
This commit is contained in:
Zach Mueller
2023-09-01 11:24:12 -04:00
committed by GitHub
parent 69c5b8f186
commit be0e189bd3
9 changed files with 31 additions and 58 deletions

View File

@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
)
def __post_init__(self):
# Compute absolute learning rate while args are mutable
super().__post_init__()
if self.base_learning_rate is not None:
total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
delattr(self, "_frozen")
self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
setattr(self, "_frozen", True)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -362,6 +353,13 @@ def main():
# Set the validation transforms
ds["validation"].set_transform(preprocess_images)
# Compute absolute learning rate
total_train_batch_size = (
training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
if training_args.base_learning_rate is not None:
training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
# Initialize our trainer
trainer = Trainer(
model=model,

View File

@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import dataclasses
import logging
import os
import sys
@@ -675,10 +674,14 @@ def main():
return result
# Override the decoding parameters of Seq2SeqTrainer
if training_args.generation_max_length is None:
training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
if training_args.generation_num_beams is None:
training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
training_args.generation_max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
training_args.generation_num_beams = (
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
)
# Initialize our Trainer
trainer = Seq2SeqTrainer(

View File

@@ -21,7 +21,6 @@ https://huggingface.co/models?filter=fill-mask
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import dataclasses
import json
import logging
import math
@@ -367,7 +366,7 @@ def main():
# If we have ref files, need to avoid it removed by trainer
has_ref = data_args.train_ref_file or data_args.validation_ref_file
if has_ref:
training_args = dataclasses.replace(training_args, remove_unused_columns=False)
training_args.remove_unused_columns = False
# Data collator
# This one will take care of randomly masking the tokens.

View File

@@ -259,6 +259,7 @@ def main():
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if training_args.output_dir is not None:
training_args.output_dir = Path(training_args.output_dir)
os.makedirs(training_args.output_dir, exist_ok=True)
# endregion
@@ -266,8 +267,8 @@ def main():
# Detecting last checkpoint.
checkpoint = None
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
config_path = Path(training_args.output_dir) / CONFIG_NAME
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
config_path = training_args.output_dir / CONFIG_NAME
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
if config_path.is_file() and weights_path.is_file():
checkpoint = training_args.output_dir
logger.info(

View File

@@ -265,6 +265,7 @@ def main():
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if training_args.output_dir is not None:
training_args.output_dir = Path(training_args.output_dir)
os.makedirs(training_args.output_dir, exist_ok=True)
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -276,8 +277,8 @@ def main():
# Detecting last checkpoint.
checkpoint = None
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
config_path = Path(training_args.output_dir) / CONFIG_NAME
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
config_path = training_args.output_dir / CONFIG_NAME
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
if config_path.is_file() and weights_path.is_file():
checkpoint = training_args.output_dir
logger.warning(