Fix deprecated arg issue (#29372)
* Fix deprecated arg issue * Trainer check too * Check for dict or dataclass * Simplify, make config always AcceleratorConfig * Upstream to Trainer
This commit is contained in:
@@ -2633,6 +2633,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||
|
||||
def test_accelerator_config_only_deprecated_args(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertWarns(FutureWarning) as cm:
|
||||
args = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
split_batches=True,
|
||||
)
|
||||
self.assertIn("split_batches", str(cm.warnings[0].message))
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user