Make training args fully immutable (#25435)
* Make training args fully immutable * Working tests, PyTorch * In test_trainer * during testing * Use proper dataclass way * Fix test * Another one * Fix tf * Lingering slow * Exception * Clean
This commit is contained in:
@@ -259,7 +259,6 @@ def main():
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
|
||||
|
||||
if training_args.output_dir is not None:
|
||||
training_args.output_dir = Path(training_args.output_dir)
|
||||
os.makedirs(training_args.output_dir, exist_ok=True)
|
||||
# endregion
|
||||
|
||||
@@ -267,8 +266,8 @@ def main():
|
||||
# Detecting last checkpoint.
|
||||
checkpoint = None
|
||||
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
|
||||
config_path = training_args.output_dir / CONFIG_NAME
|
||||
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
|
||||
config_path = Path(training_args.output_dir) / CONFIG_NAME
|
||||
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
|
||||
if config_path.is_file() and weights_path.is_file():
|
||||
checkpoint = training_args.output_dir
|
||||
logger.info(
|
||||
|
||||
@@ -265,7 +265,6 @@ def main():
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
|
||||
|
||||
if training_args.output_dir is not None:
|
||||
training_args.output_dir = Path(training_args.output_dir)
|
||||
os.makedirs(training_args.output_dir, exist_ok=True)
|
||||
|
||||
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
|
||||
@@ -277,8 +276,8 @@ def main():
|
||||
# Detecting last checkpoint.
|
||||
checkpoint = None
|
||||
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
|
||||
config_path = training_args.output_dir / CONFIG_NAME
|
||||
weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
|
||||
config_path = Path(training_args.output_dir) / CONFIG_NAME
|
||||
weights_path = Path(training_args.output_dir) / TF2_WEIGHTS_NAME
|
||||
if config_path.is_file() and weights_path.is_file():
|
||||
checkpoint = training_args.output_dir
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user