From 20932e5520fd955697743f4e2baf7879e7faa848 Mon Sep 17 00:00:00 2001 From: Kiyoung Kim Date: Wed, 27 Jan 2021 22:45:09 +0900 Subject: [PATCH] Add tpu_zone and gcp_project in training_args_tf.py (#9825) * add tpu_zone and gcp_project in training_args_tf.py * make style Co-authored-by: kykim --- datasets | 1 + src/transformers/training_args_tf.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) create mode 160000 datasets diff --git a/datasets b/datasets new file mode 160000 index 0000000000..7a6c3bae98 --- /dev/null +++ b/datasets @@ -0,0 +1 @@ +Subproject commit 7a6c3bae98d738742629ee82a48cf7b18db89154 diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index fd8187243c..bcc940164e 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -135,6 +135,12 @@ class TFTrainingArguments(TrainingArguments): at the next training step under the keyword argument ``mems``. tpu_name (:obj:`str`, `optional`): The name of the TPU the process is running on. + tpu_zone (:obj:`str`, `optional`): + The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect + from metadata. + gcp_project (:obj:`str`, `optional`): + Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to + automatically detect from metadata. run_name (:obj:`str`, `optional`): A descriptor for the run. Notably used for wandb logging. xla (:obj:`bool`, `optional`): @@ -146,6 +152,16 @@ class TFTrainingArguments(TrainingArguments): metadata={"help": "Name of TPU"}, ) + tpu_zone: str = field( + default=None, + metadata={"help": "Zone of TPU"}, + ) + + gcp_project: str = field( + default=None, + metadata={"help": "Name of Cloud TPU-enabled project"}, + ) + poly_power: float = field( default=1.0, metadata={"help": "Power for the Polynomial decay LR scheduler."}, @@ -173,7 +189,9 @@ class TFTrainingArguments(TrainingArguments): else: try: if self.tpu_name: - tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name) + tpu = tf.distribute.cluster_resolver.TPUClusterResolver( + self.tpu_name, zone=self.tpu_zone, project=self.gcp_project + ) else: tpu = tf.distribute.cluster_resolver.TPUClusterResolver() except ValueError: