Use stable functions (#9369)
This commit is contained in:
@@ -96,15 +96,15 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
||||
tf.config.experimental_connect_to_cluster(self._setup_tpu)
|
||||
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
|
||||
|
||||
strategy = tf.distribute.experimental.TPUStrategy(self._setup_tpu)
|
||||
strategy = tf.distribute.TPUStrategy(self._setup_tpu)
|
||||
else:
|
||||
# currently no multi gpu is allowed
|
||||
if self.is_gpu:
|
||||
# TODO: Currently only single GPU is supported
|
||||
tf.config.experimental.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
|
||||
tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
|
||||
strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
|
||||
else:
|
||||
tf.config.experimental.set_visible_devices([], "GPU") # disable GPU
|
||||
tf.config.set_visible_devices([], "GPU") # disable GPU
|
||||
strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
|
||||
|
||||
return strategy
|
||||
|
||||
@@ -27,7 +27,6 @@ from .integrations import ( # isort: split
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from packaging.version import parse
|
||||
from tensorflow.python.distribute.values import PerReplica
|
||||
|
||||
from .modeling_tf_utils import TFPreTrainedModel
|
||||
@@ -93,11 +92,6 @@ class TFTrainer:
|
||||
None,
|
||||
),
|
||||
):
|
||||
assert parse(tf.__version__).release >= (2, 2, 0), (
|
||||
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
|
||||
% tf.__version__
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.args = args
|
||||
self.train_dataset = train_dataset
|
||||
@@ -141,7 +135,7 @@ class TFTrainer:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
|
||||
self.num_train_examples = tf.data.experimental.cardinality(self.train_dataset).numpy()
|
||||
self.num_train_examples = self.train_dataset.cardinality(self.train_dataset).numpy()
|
||||
|
||||
if self.num_train_examples < 0:
|
||||
raise ValueError("The training dataset must have an asserted cardinality")
|
||||
@@ -173,7 +167,7 @@ class TFTrainer:
|
||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()
|
||||
num_examples = eval_dataset.cardinality(eval_dataset).numpy()
|
||||
|
||||
if num_examples < 0:
|
||||
raise ValueError("The training dataset must have an asserted cardinality")
|
||||
@@ -203,7 +197,7 @@ class TFTrainer:
|
||||
Subclass and override this method if you want to inject some custom behavior.
|
||||
"""
|
||||
|
||||
num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
|
||||
num_examples = test_dataset.cardinality(test_dataset).numpy()
|
||||
|
||||
if num_examples < 0:
|
||||
raise ValueError("The training dataset must have an asserted cardinality")
|
||||
|
||||
@@ -188,7 +188,7 @@ class TFTrainingArguments(TrainingArguments):
|
||||
tf.config.experimental_connect_to_cluster(tpu)
|
||||
tf.tpu.experimental.initialize_tpu_system(tpu)
|
||||
|
||||
strategy = tf.distribute.experimental.TPUStrategy(tpu)
|
||||
strategy = tf.distribute.TPUStrategy(tpu)
|
||||
|
||||
elif len(gpus) == 0:
|
||||
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
|
||||
|
||||
Reference in New Issue
Block a user