Smdistributed trainer (#9798)
* Add a debug print * Adapt Trainer to use smdistributed if available * Forgotten parenthesis * Real check for sagemaker * Don't forget to define device... * Woopsie, local_rank is defined differently * Update since local_rank has the proper value * Remove debug statement * More robust check for smdistributed * Quality * Deal with key not present error
This commit is contained in:
@@ -297,6 +297,20 @@ def is_pandas_available():
|
||||
return importlib.util.find_spec("pandas") is not None
|
||||
|
||||
|
||||
def is_sagemaker_distributed_available():
    """Return whether SageMaker distributed data parallelism can be used.

    The check passes only when all of the following hold:

    1. The ``SM_FRAMEWORK_PARAMS`` environment variable (defaulting to
       ``"{}"`` when unset) contains valid JSON.
    2. That JSON is an object whose
       ``"sagemaker_distributed_dataparallel_enabled"`` field is truthy.
    3. The ``smdistributed`` module can be located on the import path.

    Returns:
        bool: ``True`` if every condition above is satisfied, ``False``
        otherwise. Malformed or unexpectedly shaped values never raise.
    """
    # Get the SageMaker-specific env variable.
    raw_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")

    # Keep the try body minimal: only the parse itself can raise JSONDecodeError.
    try:
        sagemaker_params = json.loads(raw_params)
    except json.JSONDecodeError:
        return False

    # Valid JSON is not necessarily an object ("[]", "null", "1", ...); calling
    # `.get` on such a value would raise AttributeError, so treat it as disabled.
    if not isinstance(sagemaker_params, dict):
        return False

    # Check the field "sagemaker_distributed_dataparallel_enabled".
    if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
        return False

    # Lastly, check if the `smdistributed` module is present.
    return importlib.util.find_spec("smdistributed") is not None
|
||||
|
||||
|
||||
def torch_only_method(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _torch_available:
|
||||
|
||||
Reference in New Issue
Block a user