diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index c4183fa8f0..a99d5900b1 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -206,6 +206,7 @@ if ( PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) +DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) WEIGHTS_NAME = "pytorch_model.bin" TF2_WEIGHTS_NAME = "tf_model.h5" @@ -355,6 +356,10 @@ def is_sagemaker_distributed_available(): return importlib.util.find_spec("smdistributed") is not None +def is_training_run_on_sagemaker(): + return "SAGEMAKER_JOB_NAME" in os.environ and not DISABLE_TELEMETRY + + def is_soundfile_availble(): return _soundfile_available @@ -1165,6 +1170,32 @@ def cached_path( return output_path +def define_sagemaker_information(): + try: + instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json() + dlc_container_used = instance_data["Image"] + dlc_tag = instance_data["Image"].split(":")[1] + except Exception: + dlc_container_used = None + dlc_tag = None + + sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}")) + runs_distributed_training = True if "sagemaker_distributed_dataparallel_enabled" in sagemaker_params else False + account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None + + sagemaker_object = { + "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None), + "sm_region": os.getenv("AWS_REGION", None), + "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0), + "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0), + "sm_distributed_training": runs_distributed_training, + "sm_deep_learning_container": dlc_container_used, + "sm_deep_learning_container_tag": dlc_tag, + "sm_account_id": account_id, + } + return sagemaker_object + + def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: """ Formats a user-agent string with basic info about a request. @@ -1174,8 +1205,10 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: ua += f"; torch/{_torch_version}" if is_tf_available(): ua += f"; tensorflow/{_tf_version}" + if is_training_run_on_sagemaker(): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items()) if isinstance(user_agent, dict): - ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) elif isinstance(user_agent, str): ua += "; " + user_agent return ua