Raise relevent err when wrong type is passed in as the accelerator_config (#29997)
* Raise relevent err * Use type instead
This commit is contained in:
@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
|
||||
|
||||
def test_accelerator_config_not_instantiated(self):
|
||||
# Checks that accelerator kwargs can be passed through
|
||||
# and the accelerator is initialized respectively
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
_ = RegressionTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config=AcceleratorConfig,
|
||||
)
|
||||
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||
|
||||
# Now test with a custom subclass
|
||||
@dataclasses.dataclass
|
||||
class CustomAcceleratorConfig(AcceleratorConfig):
|
||||
pass
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
accelerator_config: dict = dataclasses.field(
|
||||
default=CustomAcceleratorConfig,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
_ = CustomTrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
)
|
||||
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user