From 1a7c117df96adac7b60a1f6f0f228d71b1ed1283 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Fri, 1 Mar 2024 12:00:29 -0500 Subject: [PATCH] Fix deprecated arg issue (#29372) * Fix deprecated arg issue * Trainer check too * Check for dict or dataclass * Simplify, make config always AcceleratorConfig * Upstream to Trainer --- src/transformers/trainer.py | 14 +------------- src/transformers/training_args.py | 8 +++++--- tests/trainer/test_trainer.py | 14 ++++++++++++++ 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1b70db000c..414d97eb52 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -80,7 +80,6 @@ from .trainer_callback import ( TrainerState, ) from .trainer_pt_utils import ( - AcceleratorConfig, DistributedTensorGatherer, IterableDatasetShard, LabelSmoother, @@ -4116,21 +4115,10 @@ class Trainer: gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) # create accelerator object - accelerator_kwargs = {} - if self.args.accelerator_config is not None: - accelerator_kwargs = self.args.accelerator_config - # dict and AcceleratorConfigs are parseable, json files are not - if isinstance(accelerator_kwargs, AcceleratorConfig): - accelerator_kwargs = accelerator_kwargs.to_dict() - elif isinstance(accelerator_kwargs, dict): - # Some values may need to go through non-accelerate aligned defaults - # and we need to run the `__post_init__` to set them - accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict() - self.accelerator = Accelerator( deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, - **accelerator_kwargs, + **self.args.accelerator_config.to_dict(), ) # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag self.gather_function = self.accelerator.gather_for_metrics diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 19ab24c205..ba89d914d7 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1737,9 +1737,11 @@ class TrainingArguments: os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true") if is_accelerate_available(): - if not isinstance(self.accelerator_config, (AcceleratorConfig, dict)): + if not isinstance(self.accelerator_config, (AcceleratorConfig)): if self.accelerator_config is None: self.accelerator_config = AcceleratorConfig() + elif isinstance(self.accelerator_config, dict): + self.accelerator_config = AcceleratorConfig(**self.accelerator_config) else: self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) if self.dispatch_batches is not None: @@ -1748,7 +1750,7 @@ class TrainingArguments: " `--accelerator_config {'dispatch_batches':VALUE} instead", FutureWarning, ) - self.accelerator_config["dispatch_batches"] = self.dispatch_batches + self.accelerator_config.dispatch_batches = self.dispatch_batches if self.split_batches is not None: warnings.warn( @@ -1756,7 +1758,7 @@ class TrainingArguments: " `--accelerator_config {'split_batches':VALUE} instead", FutureWarning, ) - self.accelerator_config["split_batches"] = self.split_batches + self.accelerator_config.split_batches = self.split_batches if self.tpu_metrics_debug: warnings.warn( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 65eeb6d623..1ebbe1ca7a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2633,6 +2633,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.dispatch_batches, None) + def test_accelerator_config_only_deprecated_args(self): + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertWarns(FutureWarning) as cm: + args = RegressionTrainingArguments( + output_dir=tmp_dir, + split_batches=True, + ) + self.assertIn("split_batches", str(cm.warnings[0].message)) + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.split_batches, True) + @require_torch @is_staging_test