Revert frozen training arguments (#25903)
* Revert frozen training arguments
* TODO
This commit is contained in:
@@ -163,15 +163,6 @@ class CustomTrainingArguments(TrainingArguments):
|
||||
default=1e-3, metadata={"help": "Base learning rate: absolute_lr = base_lr * total_batch_size / 256."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Compute absolute learning rate while args are mutable
|
||||
super().__post_init__()
|
||||
if self.base_learning_rate is not None:
|
||||
total_train_batch_size = self.train_batch_size * self.gradient_accumulation_steps * self.world_size
|
||||
delattr(self, "_frozen")
|
||||
self.learning_rate = self.base_learning_rate * total_train_batch_size / 256
|
||||
setattr(self, "_frozen", True)
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
@@ -362,6 +353,13 @@ def main():
|
||||
# Set the validation transforms
|
||||
ds["validation"].set_transform(preprocess_images)
|
||||
|
||||
# Compute absolute learning rate
|
||||
total_train_batch_size = (
|
||||
training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
if training_args.base_learning_rate is not None:
|
||||
training_args.learning_rate = training_args.base_learning_rate * total_train_batch_size / 256
|
||||
|
||||
# Initialize our trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
|
||||
@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
|
||||
"""
|
||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -675,10 +674,14 @@ def main():
|
||||
return result
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
if training_args.generation_max_length is None:
|
||||
training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
|
||||
if training_args.generation_num_beams is None:
|
||||
training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
|
||||
training_args.generation_max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
training_args.generation_num_beams = (
|
||||
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
|
||||
@@ -21,7 +21,6 @@ https://huggingface.co/models?filter=fill-mask
|
||||
"""
|
||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
@@ -367,7 +366,7 @@ def main():
|
||||
# If we have ref files, need to avoid it removed by trainer
|
||||
has_ref = data_args.train_ref_file or data_args.validation_ref_file
|
||||
if has_ref:
|
||||
training_args = dataclasses.replace(training_args, remove_unused_columns=False)
|
||||
training_args.remove_unused_columns = False
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
|
||||
@@ -259,6 +259,7 @@ def main():
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
|
||||
|
||||
if training_args.output_dir is not None:
|
||||
training_args.output_dir = Path(training_args.output_dir)
|
||||
os.makedirs(training_args.output_dir, exist_ok=True)
|
||||
# endregion
|
||||
|
||||
@@ -266,8 +267,8 @@ def main():
|
||||
# Detecting last checkpoint.
|
||||
checkpoint = None
|
||||
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
|
||||
config_path = Path(training_args.output_dir) / CONFIG_NAME
|
||||
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
|
||||
config_path = training_args.output_dir / CONFIG_NAME
|
||||
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
|
||||
if config_path.is_file() and weights_path.is_file():
|
||||
checkpoint = training_args.output_dir
|
||||
logger.info(
|
||||
|
||||
@@ -265,6 +265,7 @@ def main():
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
|
||||
|
||||
if training_args.output_dir is not None:
|
||||
training_args.output_dir = Path(training_args.output_dir)
|
||||
os.makedirs(training_args.output_dir, exist_ok=True)
|
||||
|
||||
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
|
||||
@@ -276,8 +277,8 @@ def main():
|
||||
# Detecting last checkpoint.
|
||||
checkpoint = None
|
||||
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
|
||||
config_path = Path(training_args.output_dir) / CONFIG_NAME
|
||||
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
|
||||
config_path = training_args.output_dir / CONFIG_NAME
|
||||
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
|
||||
if config_path.is_file() and weights_path.is_file():
|
||||
checkpoint = training_args.output_dir
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user