Revert frozen training arguments (#25903)

* Revert frozen training arguments

* TODO
This commit is contained in:
Zach Mueller
2023-09-01 11:24:12 -04:00
committed by GitHub
parent 69c5b8f186
commit be0e189bd3
9 changed files with 31 additions and 58 deletions

View File

@@ -259,6 +259,7 @@ def main():
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if training_args.output_dir is not None:
training_args.output_dir = Path(training_args.output_dir)
os.makedirs(training_args.output_dir, exist_ok=True)
# endregion
@@ -266,8 +267,8 @@ def main():
# Detecting last checkpoint.
checkpoint = None
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
config_path = Path(training_args.output_dir) / CONFIG_NAME
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
config_path = training_args.output_dir / CONFIG_NAME
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
if config_path.is_file() and weights_path.is_file():
checkpoint = training_args.output_dir
logger.info(

View File

@@ -265,6 +265,7 @@ def main():
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if training_args.output_dir is not None:
training_args.output_dir = Path(training_args.output_dir)
os.makedirs(training_args.output_dir, exist_ok=True)
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
@@ -276,8 +277,8 @@ def main():
# Detecting last checkpoint.
checkpoint = None
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
config_path = Path(training_args.output_dir) / CONFIG_NAME
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
config_path = training_args.output_dir / CONFIG_NAME
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
if config_path.is_file() and weights_path.is_file():
checkpoint = training_args.output_dir
logger.warning(