Smdistributed trainer (#9798)

* Add a debug print

* Adapt Trainer to use smdistributed if available

* Forgotten parenthesis

* Real check for sagemaker

* Donforget to define device...

* Woopsie, local)rank is defined differently

* Update since local_rank has the proper value

* Remove debug statement

* More robust check for smdistributed

* Quality

* Deal with key not present error
This commit is contained in:
Sylvain Gugger
2021-01-26 10:28:21 -05:00
committed by GitHub
parent 897a24c869
commit 0d0efd3a0e
5 changed files with 85 additions and 29 deletions

View File

@@ -297,6 +297,20 @@ def is_pandas_available():
return importlib.util.find_spec("pandas") is not None
def is_sagemaker_distributed_available():
# Get the sagemaker specific env variable.
sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
try:
# Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
sagemaker_params = json.loads(sagemaker_params)
if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
return False
except json.JSONDecodeError:
return False
# Lastly, check if the `smdistributed` module is present.
return importlib.util.find_spec("smdistributed") is not None
def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available: