Fix deprecated arg issue (#29372)
* Fix deprecated arg issue
* Trainer check too
* Check for dict or dataclass
* Simplify, make config always AcceleratorConfig
* Upstream to Trainer
This commit is contained in:
@@ -80,7 +80,6 @@ from .trainer_callback import (
|
||||
TrainerState,
|
||||
)
|
||||
from .trainer_pt_utils import (
|
||||
AcceleratorConfig,
|
||||
DistributedTensorGatherer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
@@ -4116,21 +4115,10 @@ class Trainer:
|
||||
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
||||
|
||||
# create accelerator object
|
||||
accelerator_kwargs = {}
|
||||
if self.args.accelerator_config is not None:
|
||||
accelerator_kwargs = self.args.accelerator_config
|
||||
# dict and AcceleratorConfigs are parseable, json files are not
|
||||
if isinstance(accelerator_kwargs, AcceleratorConfig):
|
||||
accelerator_kwargs = accelerator_kwargs.to_dict()
|
||||
elif isinstance(accelerator_kwargs, dict):
|
||||
# Some values may need to go through non-accelerate aligned defaults
|
||||
# and we need to run the `__post_init__` to set them
|
||||
accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict()
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
deepspeed_plugin=self.args.deepspeed_plugin,
|
||||
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
||||
**accelerator_kwargs,
|
||||
**self.args.accelerator_config.to_dict(),
|
||||
)
|
||||
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
||||
self.gather_function = self.accelerator.gather_for_metrics
|
||||
|
||||
@@ -1737,9 +1737,11 @@ class TrainingArguments:
|
||||
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
||||
|
||||
if is_accelerate_available():
|
||||
if not isinstance(self.accelerator_config, (AcceleratorConfig, dict)):
|
||||
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
|
||||
if self.accelerator_config is None:
|
||||
self.accelerator_config = AcceleratorConfig()
|
||||
elif isinstance(self.accelerator_config, dict):
|
||||
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
|
||||
else:
|
||||
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
|
||||
if self.dispatch_batches is not None:
|
||||
@@ -1748,7 +1750,7 @@ class TrainingArguments:
|
||||
" `--accelerator_config {'dispatch_batches':VALUE} instead",
|
||||
FutureWarning,
|
||||
)
|
||||
self.accelerator_config["dispatch_batches"] = self.dispatch_batches
|
||||
self.accelerator_config.dispatch_batches = self.dispatch_batches
|
||||
|
||||
if self.split_batches is not None:
|
||||
warnings.warn(
|
||||
@@ -1756,7 +1758,7 @@ class TrainingArguments:
|
||||
" `--accelerator_config {'split_batches':VALUE} instead",
|
||||
FutureWarning,
|
||||
)
|
||||
self.accelerator_config["split_batches"] = self.split_batches
|
||||
self.accelerator_config.split_batches = self.split_batches
|
||||
|
||||
if self.tpu_metrics_debug:
|
||||
warnings.warn(
|
||||
|
||||
@@ -2633,6 +2633,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||
|
||||
def test_accelerator_config_only_deprecated_args(self):
    """Check that passing only the deprecated `split_batches` argument works.

    Building the training arguments must raise a `FutureWarning` that names
    `split_batches`, and the resulting `Trainer`'s accelerator must report
    `split_batches` as `True`.
    """
    # NOTE(review): the listing this was taken from carried no indentation,
    # so the nesting below is reconstructed -- confirm against the real file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self.assertWarns(FutureWarning) as cm:
            args = RegressionTrainingArguments(output_dir=tmp_dir, split_batches=True)
            first_warning_text = str(cm.warnings[0].message)
            self.assertIn("split_batches", first_warning_text)

            regression_model = RegressionPreTrainedModel(RegressionModelConfig(a=1.5, b=2.5))
            trainer = Trainer(
                model=regression_model,
                args=args,
                eval_dataset=SampleIterableDataset(),
            )
            self.assertEqual(trainer.accelerator.split_batches, True)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user