* make output_dir optional * inintaied a basic testing module to validate and verify the changes * Test output_dir default to 'tmp_trainer' when unspecified. * test existing functionality of output_dir. * test that output dir only created when needed * final check * added doc string and changed the tmp_trainer to trainer_output * amke style fixes to test file. * another round of fixup --------- Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
This commit is contained in:
@@ -229,7 +229,7 @@ class TrainingArguments:
|
|||||||
command line.
|
command line.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
output_dir (`str`):
|
output_dir (`str`, *optional*, defaults to `"trainer_output"`):
|
||||||
The output directory where the model predictions and checkpoints will be written.
|
The output directory where the model predictions and checkpoints will be written.
|
||||||
overwrite_output_dir (`bool`, *optional*, defaults to `False`):
|
overwrite_output_dir (`bool`, *optional*, defaults to `False`):
|
||||||
If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
|
If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
|
||||||
@@ -814,8 +814,11 @@ class TrainingArguments:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
framework = "pt"
|
framework = "pt"
|
||||||
output_dir: str = field(
|
output_dir: Optional[str] = field(
|
||||||
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The output directory where the model predictions and checkpoints will be written. Defaults to 'trainer_output' if not provided."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
overwrite_output_dir: bool = field(
|
overwrite_output_dir: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -1548,6 +1551,14 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Set default output_dir if not provided
|
||||||
|
if self.output_dir is None:
|
||||||
|
self.output_dir = "trainer_output"
|
||||||
|
logger.info(
|
||||||
|
"No output directory specified, defaulting to 'trainer_output'. "
|
||||||
|
"To change this behavior, specify --output_dir when creating TrainingArguments."
|
||||||
|
)
|
||||||
|
|
||||||
# Parse in args that could be `dict` sent in from the CLI as a string
|
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||||
for field in _VALID_DICT_FIELDS:
|
for field in _VALID_DICT_FIELDS:
|
||||||
passed_value = getattr(self, field)
|
passed_value = getattr(self, field)
|
||||||
|
|||||||
42
tests/test_training_args.py
Normal file
42
tests/test_training_args.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainingArguments(unittest.TestCase):
|
||||||
|
def test_default_output_dir(self):
|
||||||
|
"""Test that output_dir defaults to 'tmp_trainer' when not specified."""
|
||||||
|
args = TrainingArguments(output_dir=None)
|
||||||
|
self.assertEqual(args.output_dir, "tmp_trainer")
|
||||||
|
|
||||||
|
def test_custom_output_dir(self):
|
||||||
|
"""Test that output_dir is respected when specified."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
args = TrainingArguments(output_dir=tmp_dir)
|
||||||
|
self.assertEqual(args.output_dir, tmp_dir)
|
||||||
|
|
||||||
|
def test_output_dir_creation(self):
|
||||||
|
"""Test that output_dir is created only when needed."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
output_dir = os.path.join(tmp_dir, "test_output")
|
||||||
|
|
||||||
|
# Directory should not exist before creating args
|
||||||
|
self.assertFalse(os.path.exists(output_dir))
|
||||||
|
|
||||||
|
# Create args with save_strategy="no" - should not create directory
|
||||||
|
args = TrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
do_train=True,
|
||||||
|
save_strategy="no",
|
||||||
|
report_to=None,
|
||||||
|
)
|
||||||
|
self.assertFalse(os.path.exists(output_dir))
|
||||||
|
|
||||||
|
# Now set save_strategy="steps" - should create directory when needed
|
||||||
|
args.save_strategy = "steps"
|
||||||
|
args.save_steps = 1
|
||||||
|
self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist
|
||||||
|
|
||||||
|
# Directory should be created when actually needed (e.g. in Trainer)
|
||||||
Reference in New Issue
Block a user