Add tpu_zone and gcp_project in training_args_tf.py (#9825)
* add tpu_zone and gcp_project in training_args_tf.py
* make style

Co-authored-by: kykim <kykim>
This commit is contained in:
@@ -135,6 +135,12 @@ class TFTrainingArguments(TrainingArguments):
|
||||
at the next training step under the keyword argument ``mems``.
|
||||
tpu_name (:obj:`str`, `optional`):
|
||||
The name of the TPU the process is running on.
|
||||
tpu_zone (:obj:`str`, `optional`):
|
||||
The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect
|
||||
from metadata.
|
||||
gcp_project (:obj:`str`, `optional`):
|
||||
Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
|
||||
automatically detect from metadata.
|
||||
run_name (:obj:`str`, `optional`):
|
||||
A descriptor for the run. Notably used for wandb logging.
|
||||
xla (:obj:`bool`, `optional`):
|
||||
@@ -146,6 +152,16 @@ class TFTrainingArguments(TrainingArguments):
|
||||
metadata={"help": "Name of TPU"},
|
||||
)
|
||||
|
||||
tpu_zone: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Zone of TPU"},
|
||||
)
|
||||
|
||||
gcp_project: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Name of Cloud TPU-enabled project"},
|
||||
)
|
||||
|
||||
poly_power: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Power for the Polynomial decay LR scheduler."},
|
||||
@@ -173,7 +189,9 @@ class TFTrainingArguments(TrainingArguments):
|
||||
else:
|
||||
try:
|
||||
if self.tpu_name:
|
||||
-tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
|
||||
+tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
|
||||
+    self.tpu_name, zone=self.tpu_zone, project=self.gcp_project
|
||||
+)
|
||||
else:
|
||||
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||||
except ValueError:
|
||||
|
||||
Reference in New Issue
Block a user