@@ -58,7 +58,7 @@ class MultiNodeTest(unittest.TestCase):
|
||||
def create_estimator(self, instance_count):
|
||||
job_name = f"{self.env.base_job_name}-{instance_count}-{'ddp' if 'ddp' in self.script else 'smd'}"
|
||||
# distributed data settings
|
||||
distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
|
||||
distribution = {"smdistributed": {"dataparallel": {"enabled": True}}} if self.script != "run_ddp.py" else None
|
||||
|
||||
# creates estimator
|
||||
return HuggingFace(
|
||||
|
||||
Reference in New Issue
Block a user