From 1b57de8dcf2ab25c1b081c23679e2c964ce3da92 Mon Sep 17 00:00:00 2001 From: Sambhav Dixit <94298612+sambhavnoobcoder@users.noreply.github.com> Date: Tue, 11 Feb 2025 23:24:36 +0530 Subject: [PATCH] Make `output_dir` Optional in `TrainingArguments` #27866 (#35735) * make output_dir optional * initiated a basic testing module to validate and verify the changes * Test output_dir default to 'tmp_trainer' when unspecified. * test existing functionality of output_dir. * test that output dir only created when needed * final check * added doc string and changed the tmp_trainer to trainer_output * make style fixes to test file. * another round of fixup --------- Co-authored-by: sambhavnoobcoder --- src/transformers/training_args.py | 17 ++++++++++--- tests/test_training_args.py | 42 +++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 tests/test_training_args.py diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5bc31b6160..d1d2d2c6a3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -229,7 +229,7 @@ class TrainingArguments: command line. Parameters: - output_dir (`str`): + output_dir (`str`, *optional*, defaults to `"trainer_output"`): The output directory where the model predictions and checkpoints will be written. overwrite_output_dir (`bool`, *optional*, defaults to `False`): If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` @@ -814,8 +814,11 @@ class TrainingArguments: """ framework = "pt" - output_dir: str = field( - metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + output_dir: Optional[str] = field( + default=None, + metadata={ + "help": "The output directory where the model predictions and checkpoints will be written. Defaults to 'trainer_output' if not provided." 
+ }, ) overwrite_output_dir: bool = field( default=False, @@ -1548,6 +1551,14 @@ class TrainingArguments: ) def __post_init__(self): + # Set default output_dir if not provided + if self.output_dir is None: + self.output_dir = "trainer_output" + logger.info( + "No output directory specified, defaulting to 'trainer_output'. " + "To change this behavior, specify --output_dir when creating TrainingArguments." + ) + # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: passed_value = getattr(self, field) diff --git a/tests/test_training_args.py b/tests/test_training_args.py new file mode 100644 index 0000000000..7b1daabe16 --- /dev/null +++ b/tests/test_training_args.py @@ -0,0 +1,42 @@ +import os +import tempfile +import unittest + +from transformers import TrainingArguments + + +class TestTrainingArguments(unittest.TestCase): + def test_default_output_dir(self): + """Test that output_dir defaults to 'tmp_trainer' when not specified.""" + args = TrainingArguments(output_dir=None) + self.assertEqual(args.output_dir, "tmp_trainer") + + def test_custom_output_dir(self): + """Test that output_dir is respected when specified.""" + with tempfile.TemporaryDirectory() as tmp_dir: + args = TrainingArguments(output_dir=tmp_dir) + self.assertEqual(args.output_dir, tmp_dir) + + def test_output_dir_creation(self): + """Test that output_dir is created only when needed.""" + with tempfile.TemporaryDirectory() as tmp_dir: + output_dir = os.path.join(tmp_dir, "test_output") + + # Directory should not exist before creating args + self.assertFalse(os.path.exists(output_dir)) + + # Create args with save_strategy="no" - should not create directory + args = TrainingArguments( + output_dir=output_dir, + do_train=True, + save_strategy="no", + report_to=None, + ) + self.assertFalse(os.path.exists(output_dir)) + + # Now set save_strategy="steps" - should create directory when needed + args.save_strategy = "steps" + args.save_steps = 1 + 
self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist + + # Directory should be created when actually needed (e.g. in Trainer)