Make output_dir Optional in TrainingArguments #27866 (#35735)

* make output_dir optional

* initiated a basic testing module to validate and verify the changes

* Test output_dir default to 'tmp_trainer' when unspecified.

* test existing functionality of output_dir.

* test that output dir only created when needed

* final check

* added doc string and changed the tmp_trainer to trainer_output

* make style fixes to test file.

* another round of fixup

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
This commit is contained in:
Sambhav Dixit
2025-02-11 23:24:36 +05:30
committed by GitHub
parent 03534a92f8
commit 1b57de8dcf
2 changed files with 56 additions and 3 deletions

View File

@@ -229,7 +229,7 @@ class TrainingArguments:
command line. command line.
Parameters: Parameters:
output_dir (`str`): output_dir (`str`, *optional*, defaults to `"trainer_output"`):
The output directory where the model predictions and checkpoints will be written. The output directory where the model predictions and checkpoints will be written.
overwrite_output_dir (`bool`, *optional*, defaults to `False`): overwrite_output_dir (`bool`, *optional*, defaults to `False`):
If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
@@ -814,8 +814,11 @@ class TrainingArguments:
""" """
framework = "pt" framework = "pt"
output_dir: str = field( output_dir: Optional[str] = field(
metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, default=None,
metadata={
"help": "The output directory where the model predictions and checkpoints will be written. Defaults to 'trainer_output' if not provided."
},
) )
overwrite_output_dir: bool = field( overwrite_output_dir: bool = field(
default=False, default=False,
@@ -1548,6 +1551,14 @@ class TrainingArguments:
) )
def __post_init__(self): def __post_init__(self):
# Set default output_dir if not provided
if self.output_dir is None:
self.output_dir = "trainer_output"
logger.info(
"No output directory specified, defaulting to 'trainer_output'. "
"To change this behavior, specify --output_dir when creating TrainingArguments."
)
# Parse in args that could be `dict` sent in from the CLI as a string # Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS: for field in _VALID_DICT_FIELDS:
passed_value = getattr(self, field) passed_value = getattr(self, field)

View File

@@ -0,0 +1,42 @@
import os
import tempfile
import unittest
from transformers import TrainingArguments
class TestTrainingArguments(unittest.TestCase):
    """Tests for how `TrainingArguments` resolves and handles `output_dir`."""

    def test_default_output_dir(self):
        """Test that output_dir defaults to 'trainer_output' when not specified."""
        # An explicit `None` is replaced by the default in `__post_init__`.
        args = TrainingArguments(output_dir=None)
        self.assertEqual(args.output_dir, "trainer_output")

        # Omitting the argument entirely must resolve to the same default,
        # since the field itself now defaults to `None`.
        args = TrainingArguments()
        self.assertEqual(args.output_dir, "trainer_output")

    def test_custom_output_dir(self):
        """Test that output_dir is respected when specified."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(output_dir=tmp_dir)
            self.assertEqual(args.output_dir, tmp_dir)

    def test_output_dir_creation(self):
        """Test that output_dir is created only when needed."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = os.path.join(tmp_dir, "test_output")

            # Directory should not exist before creating args
            self.assertFalse(os.path.exists(output_dir))

            # Create args with save_strategy="no" - should not create directory
            args = TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="no",
                report_to=None,
            )
            self.assertFalse(os.path.exists(output_dir))

            # Now set save_strategy="steps" - should create directory when needed
            args.save_strategy = "steps"
            args.save_steps = 1
            self.assertFalse(os.path.exists(output_dir))  # Still shouldn't exist

            # Directory should be created when actually needed (e.g. in Trainer)