From 7eadfe166e29df0916fe4b4a2c9b56fe1f071936 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 29 Jan 2021 09:52:26 -0500 Subject: [PATCH] When on sagemaker use their env variables for saves (#9876) * When on sagemaker use their env variables for saves * Address review comments * Quality --- src/transformers/trainer.py | 5 +++++ src/transformers/training_args.py | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8ca84beefe..c25c7cb42d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1366,6 +1366,11 @@ class Trainer: elif self.is_world_process_zero(): self._save(output_dir) + # If on sagemaker and we are saving the main model (not a checkpoint so output_dir=None), save a copy to + # SM_MODEL_DIR for easy deployment. + if output_dir is None and os.getenv("SM_MODEL_DIR") is not None: + self.save_model(output_dir=os.getenv("SM_MODEL_DIR")) + def _save_tpu(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir logger.info("Saving model checkpoint to %s", output_dir) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 71d5255944..b426f2c430 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -248,8 +248,9 @@ class TrainingArguments: Whether you want to pin memory in data loaders or not. Will default to :obj:`True`. """ - output_dir: str = field( - metadata={"help": "The output directory where the model predictions and checkpoints will be written."} + output_dir: Optional[str] = field( + default=None, + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, ) overwrite_output_dir: bool = field( default=False, @@ -444,6 +445,18 @@ class TrainingArguments: _n_gpu: int = field(init=False, repr=False, default=-1) def __post_init__(self): + if self.output_dir is None and os.getenv("SM_OUTPUT_DATA_DIR") is None: + raise ValueError( + "`output_dir` is only optional if it can get inferred from the environment. Please set a value for " + "`output_dir`." + ) + elif os.getenv("SM_OUTPUT_DATA_DIR") is not None: + if self.output_dir is not None: + logger.warn( + "`output_dir` is overwritten by the env variable 'SM_OUTPUT_DATA_DIR' " + f"({os.getenv('SM_OUTPUT_DATA_DIR')})." + ) + self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR") if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)