Use stable functions (#9369)

This commit is contained in:
Julien Plu
2021-01-05 09:58:26 +01:00
committed by GitHub
parent 4aa8f6ad99
commit 4225740a7b
5 changed files with 12 additions and 26 deletions

View File

@@ -96,15 +96,15 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
tf.config.experimental_connect_to_cluster(self._setup_tpu)
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
strategy = tf.distribute.experimental.TPUStrategy(self._setup_tpu)
strategy = tf.distribute.TPUStrategy(self._setup_tpu)
else:
# currently no multi gpu is allowed
if self.is_gpu:
# TODO: Currently only single GPU is supported
tf.config.experimental.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
else:
tf.config.experimental.set_visible_devices([], "GPU") # disable GPU
tf.config.set_visible_devices([], "GPU") # disable GPU
strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
return strategy

View File

@@ -27,7 +27,6 @@ from .integrations import ( # isort: split
import numpy as np
import tensorflow as tf
from packaging.version import parse
from tensorflow.python.distribute.values import PerReplica
from .modeling_tf_utils import TFPreTrainedModel
@@ -93,11 +92,6 @@ class TFTrainer:
None,
),
):
assert parse(tf.__version__).release >= (2, 2, 0), (
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
% tf.__version__
)
self.model = model
self.args = args
self.train_dataset = train_dataset
@@ -141,7 +135,7 @@ class TFTrainer:
raise ValueError("Trainer: training requires a train_dataset.")
self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
self.num_train_examples = tf.data.experimental.cardinality(self.train_dataset).numpy()
self.num_train_examples = self.train_dataset.cardinality(self.train_dataset).numpy()
if self.num_train_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")
@@ -173,7 +167,7 @@ class TFTrainer:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()
num_examples = eval_dataset.cardinality(eval_dataset).numpy()
if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")
@@ -203,7 +197,7 @@ class TFTrainer:
Subclass and override this method if you want to inject some custom behavior.
"""
num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
num_examples = test_dataset.cardinality(test_dataset).numpy()
if num_examples < 0:
raise ValueError("The training dataset must have an asserted cardinality")

View File

@@ -188,7 +188,7 @@ class TFTrainingArguments(TrainingArguments):
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
strategy = tf.distribute.TPUStrategy(tpu)
elif len(gpus) == 0:
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")