Trigger add sm information (#10610)
* added sm to ua * update id * removed id * removed comments * added env variable * changed variable name * make quality happy * added sguggers feedback * make styling happy and remove brackets * added sm to ua * update id * removed id * removed comments * added env variable * changed variable name * make quality happy * added sguggers feedback * make styling happy and remove brackets
This commit is contained in:
@@ -206,6 +206,7 @@ if (
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
# Opt-out switch for telemetry. An unset/empty variable or an explicit "off" value ("0", "false", "no", "off",
# case-insensitive) leaves telemetry enabled; any other value disables it. Without this parsing every non-empty
# string -- including "0" and "false" -- would be truthy and count as an opt-out.
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").strip().lower() not in {"", "0", "false", "no", "off"}
|
||||
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||
@@ -355,6 +356,10 @@ def is_sagemaker_distributed_available():
|
||||
return importlib.util.find_spec("smdistributed") is not None
|
||||
|
||||
|
||||
def is_training_run_on_sagemaker():
    """
    Tell whether SageMaker information should be reported for the current process.

    Returns:
        bool: ``True`` only when the ``SAGEMAKER_JOB_NAME`` environment variable is set and the module-level
        ``DISABLE_TELEMETRY`` flag is falsy.
    """
    # The opt-out flag wins regardless of the environment we are running in.
    if DISABLE_TELEMETRY:
        return False
    return "SAGEMAKER_JOB_NAME" in os.environ
|
||||
|
||||
|
||||
def is_soundfile_availble():
    """
    Return the module-level ``_soundfile_available`` flag.

    The flag itself is computed elsewhere in this module; this accessor only exposes it.
    """
    soundfile_flag = _soundfile_available
    return soundfile_flag
|
||||
|
||||
@@ -1165,6 +1170,32 @@ def cached_path(
|
||||
return output_path
|
||||
|
||||
|
||||
def define_sagemaker_information():
    """
    Collect details about the current SageMaker training environment.

    All values are gathered on a best-effort basis: anything that is missing or cannot be parsed is reported with a
    fallback value (``None``, ``0`` or ``False``) instead of raising, so assembling this information never breaks
    the caller.

    Returns:
        dict: A dictionary with the keys ``sm_framework``, ``sm_region``, ``sm_number_gpu``, ``sm_number_cpu``,
        ``sm_distributed_training``, ``sm_deep_learning_container``, ``sm_deep_learning_container_tag`` and
        ``sm_account_id``.
    """
    # Container image and tag come from the metadata endpoint named by ECS_CONTAINER_METADATA_URI.
    try:
        # A timeout keeps an unresponsive metadata endpoint from blocking the caller forever.
        instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"], timeout=5).json()
        dlc_container_used = instance_data["Image"]
        dlc_tag = instance_data["Image"].split(":")[1]
    except Exception:
        # Deliberately broad: a missing variable, a network error or an unexpected payload all mean "unknown".
        dlc_container_used = None
        dlc_tag = None

    # Distributed training is detected by the mere presence of the key in the framework parameters.
    try:
        sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
        runs_distributed_training = "sagemaker_distributed_dataparallel_enabled" in sagemaker_params
    except (ValueError, TypeError):
        # Invalid JSON (ValueError) or a JSON value that is not a container (TypeError).
        runs_distributed_training = False

    # The account id is the fifth ":"-separated field of the training job ARN; a missing or malformed ARN
    # yields None instead of an IndexError.
    arn_fields = os.getenv("TRAINING_JOB_ARN", "").split(":")
    account_id = arn_fields[4] if len(arn_fields) > 4 else None

    sagemaker_object = {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": runs_distributed_training,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }
    return sagemaker_object
|
||||
|
||||
|
||||
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
"""
|
||||
Formats a user-agent string with basic info about a request.
|
||||
@@ -1174,8 +1205,10 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
ua += f"; torch/{_torch_version}"
|
||||
if is_tf_available():
|
||||
ua += f"; tensorflow/{_tf_version}"
|
||||
if is_training_run_on_sagemaker():
|
||||
ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
||||
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += "; " + user_agent
|
||||
return ua
|
||||
|
||||
Reference in New Issue
Block a user