Fix deprecated arg issue (#29372)
* Fix deprecated arg issue
* Trainer check too
* Check for dict or dataclass
* Simplify, make config always AcceleratorConfig
* Upstream to Trainer
This commit is contained in:
@@ -80,7 +80,6 @@ from .trainer_callback import (
|
|||||||
TrainerState,
|
TrainerState,
|
||||||
)
|
)
|
||||||
from .trainer_pt_utils import (
|
from .trainer_pt_utils import (
|
||||||
AcceleratorConfig,
|
|
||||||
DistributedTensorGatherer,
|
DistributedTensorGatherer,
|
||||||
IterableDatasetShard,
|
IterableDatasetShard,
|
||||||
LabelSmoother,
|
LabelSmoother,
|
||||||
@@ -4116,21 +4115,10 @@ class Trainer:
|
|||||||
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
||||||
|
|
||||||
# create accelerator object
|
# create accelerator object
|
||||||
accelerator_kwargs = {}
|
|
||||||
if self.args.accelerator_config is not None:
|
|
||||||
accelerator_kwargs = self.args.accelerator_config
|
|
||||||
# dict and AcceleratorConfigs are parseable, json files are not
|
|
||||||
if isinstance(accelerator_kwargs, AcceleratorConfig):
|
|
||||||
accelerator_kwargs = accelerator_kwargs.to_dict()
|
|
||||||
elif isinstance(accelerator_kwargs, dict):
|
|
||||||
# Some values may need to go through non-accelerate aligned defaults
|
|
||||||
# and we need to run the `__post_init__` to set them
|
|
||||||
accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict()
|
|
||||||
|
|
||||||
self.accelerator = Accelerator(
|
self.accelerator = Accelerator(
|
||||||
deepspeed_plugin=self.args.deepspeed_plugin,
|
deepspeed_plugin=self.args.deepspeed_plugin,
|
||||||
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
||||||
**accelerator_kwargs,
|
**self.args.accelerator_config.to_dict(),
|
||||||
)
|
)
|
||||||
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
||||||
self.gather_function = self.accelerator.gather_for_metrics
|
self.gather_function = self.accelerator.gather_for_metrics
|
||||||
|
|||||||
@@ -1737,9 +1737,11 @@ class TrainingArguments:
|
|||||||
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
if not isinstance(self.accelerator_config, (AcceleratorConfig, dict)):
|
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
|
||||||
if self.accelerator_config is None:
|
if self.accelerator_config is None:
|
||||||
self.accelerator_config = AcceleratorConfig()
|
self.accelerator_config = AcceleratorConfig()
|
||||||
|
elif isinstance(self.accelerator_config, dict):
|
||||||
|
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
|
||||||
else:
|
else:
|
||||||
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
|
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
|
||||||
if self.dispatch_batches is not None:
|
if self.dispatch_batches is not None:
|
||||||
@@ -1748,7 +1750,7 @@ class TrainingArguments:
|
|||||||
" `--accelerator_config {'dispatch_batches':VALUE} instead",
|
" `--accelerator_config {'dispatch_batches':VALUE} instead",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
self.accelerator_config["dispatch_batches"] = self.dispatch_batches
|
self.accelerator_config.dispatch_batches = self.dispatch_batches
|
||||||
|
|
||||||
if self.split_batches is not None:
|
if self.split_batches is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@@ -1756,7 +1758,7 @@ class TrainingArguments:
|
|||||||
" `--accelerator_config {'split_batches':VALUE} instead",
|
" `--accelerator_config {'split_batches':VALUE} instead",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
self.accelerator_config["split_batches"] = self.split_batches
|
self.accelerator_config.split_batches = self.split_batches
|
||||||
|
|
||||||
if self.tpu_metrics_debug:
|
if self.tpu_metrics_debug:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|||||||
@@ -2633,6 +2633,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.accelerator.even_batches, False)
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||||
|
|
||||||
|
def test_accelerator_config_only_deprecated_args(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertWarns(FutureWarning) as cm:
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
split_batches=True,
|
||||||
|
)
|
||||||
|
self.assertIn("split_batches", str(cm.warnings[0].message))
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user