Raise relevent err when wrong type is passed in as the accelerator_config (#29997)
* Raise relevent err * Use type instead
This commit is contained in:
@@ -1815,6 +1815,13 @@ class TrainingArguments:
|
||||
self.accelerator_config = AcceleratorConfig()
|
||||
elif isinstance(self.accelerator_config, dict):
|
||||
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
|
||||
# Check that a user didn't pass in the class instantiator
|
||||
# such as `accelerator_config = AcceleratorConfig`
|
||||
elif isinstance(self.accelerator_config, type):
|
||||
raise NotImplementedError(
|
||||
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
|
||||
"Please pass in a fully constructed `AcceleratorConfig` object instead."
|
||||
)
|
||||
else:
|
||||
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
|
||||
if self.dispatch_batches is not None:
|
||||
|
||||
@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
|
||||
|
||||
def test_accelerator_config_not_instantiated(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
_ = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config=AcceleratorConfig,
|
||||
)
|
||||
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||
|
||||
# Now test with a custom subclass
|
||||
@dataclasses.dataclass
|
||||
class CustomAcceleratorConfig(AcceleratorConfig):
|
||||
pass
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
accelerator_config: dict = dataclasses.field(
|
||||
default=CustomAcceleratorConfig,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
_ = CustomTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
)
|
||||
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user