Fix deprecated arg issue (#29372)

* Fix deprecated arg issue

* Trainer check too

* Check for dict or dataclass

* Simplify, make config always AcceleratorConfig

* Upstream to Trainer
This commit is contained in:
Zach Mueller
2024-03-01 12:00:29 -05:00
committed by GitHub
parent cec773345a
commit 1a7c117df9
3 changed files with 20 additions and 16 deletions

View File

@@ -2633,6 +2633,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.dispatch_batches, None)
def test_accelerator_config_only_deprecated_args(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertWarns(FutureWarning) as cm:
args = RegressionTrainingArguments(
output_dir=tmp_dir,
split_batches=True,
)
self.assertIn("split_batches", str(cm.warnings[0].message))
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
@require_torch
@is_staging_test