Raise relevent err when wrong type is passed in as the accelerator_config (#29997)
* Raise relevent err * Use type instead
This commit is contained in:
@@ -1815,6 +1815,13 @@ class TrainingArguments:
|
|||||||
self.accelerator_config = AcceleratorConfig()
|
self.accelerator_config = AcceleratorConfig()
|
||||||
elif isinstance(self.accelerator_config, dict):
|
elif isinstance(self.accelerator_config, dict):
|
||||||
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
|
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
|
||||||
|
# Check that a user didn't pass in the class instantiator
|
||||||
|
# such as `accelerator_config = AcceleratorConfig`
|
||||||
|
elif isinstance(self.accelerator_config, type):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
|
||||||
|
"Please pass in a fully constructed `AcceleratorConfig` object instead."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
|
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
|
||||||
if self.dispatch_batches is not None:
|
if self.dispatch_batches is not None:
|
||||||
|
|||||||
@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
|
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
|
||||||
|
|
||||||
|
def test_accelerator_config_not_instantiated(self):
|
||||||
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
# and the accelerator is initialized respectively
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(NotImplementedError) as context:
|
||||||
|
_ = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
accelerator_config=AcceleratorConfig,
|
||||||
|
)
|
||||||
|
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||||
|
|
||||||
|
# Now test with a custom subclass
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CustomAcceleratorConfig(AcceleratorConfig):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CustomTrainingArguments(TrainingArguments):
|
||||||
|
accelerator_config: dict = dataclasses.field(
|
||||||
|
default=CustomAcceleratorConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(NotImplementedError) as context:
|
||||||
|
_ = CustomTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
)
|
||||||
|
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user