Trigger add sm information (#10610)
* added sm to ua * update id * removed id * removed comments * added env variable * changed variable name * make quality happy * added sgugger's feedback * make styling happy and remove brackets
This commit is contained in:
@@ -206,6 +206,7 @@ if (
|
|||||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||||
|
# Opt-out switch for telemetry. The raw environment value is a string, so "0" or "false" would be truthy if used
# directly; parse it into a real bool. Any value other than an explicit negative keeps telemetry disabled.
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").strip().lower() not in {"", "0", "false", "no", "off"}
|
||||||
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||||
@@ -355,6 +356,10 @@ def is_sagemaker_distributed_available():
|
|||||||
return importlib.util.find_spec("smdistributed") is not None
|
return importlib.util.find_spec("smdistributed") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_training_run_on_sagemaker():
    """
    Tell whether SageMaker details should be reported.

    Returns:
        :obj:`bool`: :obj:`True` when the ``SAGEMAKER_JOB_NAME`` environment variable is present and
        ``DISABLE_TELEMETRY`` is falsy, :obj:`False` otherwise.
    """
    if "SAGEMAKER_JOB_NAME" not in os.environ:
        return False
    return not DISABLE_TELEMETRY
|
||||||
|
|
||||||
|
|
||||||
def is_soundfile_availble():
    """Return the module-level ``_soundfile_available`` flag unchanged."""
    return _soundfile_available
|
||||||
|
|
||||||
@@ -1165,6 +1170,32 @@ def cached_path(
|
|||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def define_sagemaker_information():
    """
    Collect details about the current SageMaker environment from environment variables and the container metadata
    endpoint.

    The lookup is best-effort: a missing or malformed value is reported as :obj:`None` (or the default shown below)
    instead of raising, so building the result never fails because of the environment.

    Returns:
        :obj:`dict`: A dictionary with the keys ``sm_framework``, ``sm_region``, ``sm_number_gpu``,
        ``sm_number_cpu``, ``sm_distributed_training``, ``sm_deep_learning_container``,
        ``sm_deep_learning_container_tag`` and ``sm_account_id``.
    """
    dlc_container_used = None
    dlc_tag = None
    metadata_uri = os.environ.get("ECS_CONTAINER_METADATA_URI")
    if metadata_uri:
        try:
            # Bounded wait: without a timeout an unresponsive endpoint would block the caller indefinitely.
            instance_data = requests.get(metadata_uri, timeout=5).json()
            dlc_container_used = instance_data["Image"]
            # Search for the tag only after the last "/" so a "host:port" registry prefix is not mistaken for it.
            _, tag_separator, tag = dlc_container_used.rsplit("/", 1)[-1].partition(":")
            dlc_tag = tag if tag_separator else None
        except Exception:
            # Deliberately broad: network, JSON and schema problems must all degrade to "unknown".
            dlc_container_used = None
            dlc_tag = None

    try:
        sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    except ValueError:
        # Invalid JSON is treated the same as no framework parameters at all.
        sagemaker_params = {}
    # Only the presence of the key is checked, not its value.
    runs_distributed_training = (
        isinstance(sagemaker_params, dict) and "sagemaker_distributed_dataparallel_enabled" in sagemaker_params
    )

    # The account id is the fifth ":"-separated field of the ARN; tolerate a missing or malformed value.
    arn_fields = os.getenv("TRAINING_JOB_ARN", "").split(":")
    account_id = arn_fields[4] if len(arn_fields) > 4 else None

    sagemaker_object = {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": runs_distributed_training,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }
    return sagemaker_object
|
||||||
|
|
||||||
|
|
||||||
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Formats a user-agent string with basic info about a request.
|
Formats a user-agent string with basic info about a request.
|
||||||
@@ -1174,8 +1205,10 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
|||||||
ua += f"; torch/{_torch_version}"
|
ua += f"; torch/{_torch_version}"
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
ua += f"; tensorflow/{_tf_version}"
|
ua += f"; tensorflow/{_tf_version}"
|
||||||
|
if is_training_run_on_sagemaker():
|
||||||
|
ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
|
||||||
if isinstance(user_agent, dict):
|
if isinstance(user_agent, dict):
|
||||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
|
||||||
elif isinstance(user_agent, str):
|
elif isinstance(user_agent, str):
|
||||||
ua += "; " + user_agent
|
ua += "; " + user_agent
|
||||||
return ua
|
return ua
|
||||||
|
|||||||
Reference in New Issue
Block a user