Update TF text classification example (#11496)
Big refactor, fixes and multi-GPU/TPU support
This commit is contained in:
@@ -54,6 +54,20 @@ After training, the model will be saved to `--output_dir`. Once your model is tr
|
||||
by calling the script without a `--train_file` or `--validation_file`; simply pass it the output_dir containing
|
||||
the trained model and a `--test_file` and it will write its predictions to a text file for you.
|
||||
|
||||
### Multi-GPU and TPU usage
|
||||
|
||||
By default, the script uses a `MirroredStrategy` and will use multiple GPUs effectively if they are available. TPUs
|
||||
can also be used by passing the name of the TPU resource with the `--tpu` argument.
|
||||
|
||||
### Memory usage and data loading
|
||||
|
||||
One thing to note is that all data is loaded into memory in this script. Most text classification datasets are small
|
||||
enough that this is not an issue, but if you have a very large dataset you will need to modify the script to handle
|
||||
data streaming. This is particularly challenging for TPUs, given the stricter requirements and the sheer volume of data
|
||||
required to keep them fed. A full explanation of all the possible pitfalls is a bit beyond this example script and
|
||||
README, but for more information you can see the 'Input Datasets' section of
|
||||
[this document](https://www.tensorflow.org/guide/tpu).
|
||||
|
||||
### Example command
|
||||
```
|
||||
python run_text_classification.py \
|
||||
|
||||
@@ -18,10 +18,8 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -34,7 +32,7 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
PretrainedConfig,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TrainingArguments,
|
||||
TFTrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME
|
||||
@@ -48,65 +46,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# region Helper classes
|
||||
class DataSequence(tf.keras.utils.Sequence):
|
||||
# We use a Sequence object to load the data. Although it's completely possible to load your data as Numpy/TF arrays
|
||||
# and pass those straight to the Model, this constrains you in a couple of ways. Most notably, it requires all
|
||||
# the data to be padded to the length of the longest input example, and it also requires the whole dataset to be
|
||||
# loaded into memory. If these aren't major problems for you, you can skip the sequence object in your own code!
|
||||
def __init__(self, dataset, non_label_column_names, batch_size, labels, shuffle=True):
|
||||
super().__init__()
|
||||
# Retain all of the columns not present in the original data - these are the ones added by the tokenizer
|
||||
self.data = {
|
||||
key: dataset[key]
|
||||
for key in dataset.features.keys()
|
||||
if key not in non_label_column_names and key != "label"
|
||||
}
|
||||
data_lengths = {len(array) for array in self.data.values()}
|
||||
assert len(data_lengths) == 1, "Dataset arrays differ in length!"
|
||||
self.data_length = data_lengths.pop()
|
||||
self.num_batches = ceil(self.data_length / batch_size)
|
||||
if labels:
|
||||
self.labels = np.array(dataset["label"])
|
||||
assert len(self.labels) == self.data_length, "Labels not the same length as input arrays!"
|
||||
else:
|
||||
self.labels = None
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
if self.shuffle:
|
||||
# Shuffle the data order
|
||||
self.permutation = np.random.permutation(self.data_length)
|
||||
else:
|
||||
self.permutation = None
|
||||
|
||||
def on_epoch_end(self):
|
||||
# If we're shuffling, reshuffle the data order after each epoch
|
||||
if self.shuffle:
|
||||
self.permutation = np.random.permutation(self.data_length)
|
||||
|
||||
def __getitem__(self, item):
|
||||
# Note that this yields a batch, not a single sample
|
||||
batch_start = item * self.batch_size
|
||||
batch_end = (item + 1) * self.batch_size
|
||||
if self.shuffle:
|
||||
data_indices = self.permutation[batch_start:batch_end]
|
||||
else:
|
||||
data_indices = np.arange(batch_start, batch_end)
|
||||
# We want to pad the data as little as possible, so we only pad each batch
|
||||
# to the maximum length within that batch. We do that by stacking the variable-
|
||||
# length inputs into a ragged tensor and then densifying it.
|
||||
batch_input = {
|
||||
key: tf.ragged.constant([data[i] for i in data_indices]).to_tensor() for key, data in self.data.items()
|
||||
}
|
||||
if self.labels is None:
|
||||
return batch_input
|
||||
else:
|
||||
batch_labels = self.labels[data_indices]
|
||||
return batch_input, batch_labels
|
||||
|
||||
def __len__(self):
|
||||
return self.num_batches
|
||||
|
||||
|
||||
class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
||||
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
|
||||
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
|
||||
@@ -119,8 +58,50 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
||||
self.model.save_pretrained(self.output_dir)
|
||||
|
||||
|
||||
def convert_dataset_for_tensorflow(
|
||||
dataset, non_label_column_names, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
|
||||
):
|
||||
"""Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
|
||||
to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
|
||||
is most useful when training on TPU, as a new graph compilation is required for each sequence length.
|
||||
"""
|
||||
|
||||
def densify_ragged_batch(features, label=None):
|
||||
features = {
|
||||
feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) for feature, ragged_tensor in features.items()
|
||||
}
|
||||
if label is None:
|
||||
return features
|
||||
else:
|
||||
return features, label
|
||||
|
||||
feature_keys = list(set(dataset.features.keys()) - set(non_label_column_names + ["label"]))
|
||||
if dataset_mode == "variable_batch":
|
||||
batch_shape = {key: None for key in feature_keys}
|
||||
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
|
||||
elif dataset_mode == "constant_batch":
|
||||
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
|
||||
batch_shape = {
|
||||
key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
|
||||
for key, ragged_tensor in data.items()
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unknown dataset mode!")
|
||||
|
||||
if "label" in dataset.features:
|
||||
labels = tf.convert_to_tensor(np.array(dataset["label"]))
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
|
||||
else:
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices(data)
|
||||
if shuffle:
|
||||
tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
|
||||
tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch)
|
||||
return tf_dataset
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Command-line arguments
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
@@ -155,6 +136,7 @@ class DataTrainingArguments:
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
"Data will always be padded when using TPUs."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
@@ -164,17 +146,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
max_val_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
max_test_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -223,6 +205,7 @@ class ModelArguments:
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
tpu: Optional[str] = field(default=None, metadata={"help": "Name of the TPU resource to use, if available"})
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -234,7 +217,7 @@ def main():
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
@@ -322,12 +305,7 @@ def main():
|
||||
is_regression = None
|
||||
# endregion
|
||||
|
||||
# region Load pretrained model and tokenizer
|
||||
# Set seed before initializing model
|
||||
set_seed(training_args.seed)
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
# region Load model config and tokenizer
|
||||
if checkpoint is not None:
|
||||
config_path = training_args.output_dir
|
||||
elif model_args.config_name:
|
||||
@@ -355,34 +333,6 @@ def main():
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
if checkpoint is None:
|
||||
model_path = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = checkpoint
|
||||
model = TFAutoModelForSequenceClassification.from_pretrained(
|
||||
model_path,
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region Optimizer, loss and compilation
|
||||
optimizer = tf.keras.optimizers.Adam(
|
||||
learning_rate=training_args.learning_rate,
|
||||
beta_1=training_args.adam_beta1,
|
||||
beta_2=training_args.adam_beta2,
|
||||
epsilon=training_args.adam_epsilon,
|
||||
clipnorm=training_args.max_grad_norm,
|
||||
)
|
||||
if is_regression:
|
||||
loss = tf.keras.losses.MeanSquaredError()
|
||||
metrics = []
|
||||
else:
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metrics = ["accuracy"]
|
||||
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
||||
# endregion
|
||||
|
||||
# region Dataset preprocessing
|
||||
@@ -399,13 +349,6 @@ def main():
|
||||
else:
|
||||
sentence1_key, sentence2_key = non_label_column_names[0], None
|
||||
|
||||
# Padding strategy
|
||||
if data_args.pad_to_max_length:
|
||||
padding = "max_length"
|
||||
else:
|
||||
# We will pad later, dynamically at batch creation, to the max sequence length in each batch
|
||||
padding = False
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
@@ -415,8 +358,8 @@ def main():
|
||||
|
||||
# Ensure that our labels match the model's, if it has some pre-specified
|
||||
if "train" in datasets:
|
||||
if not is_regression and model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||
label_name_to_id = model.config.label2id
|
||||
if not is_regression and config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||
label_name_to_id = config.label2id
|
||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
||||
label_to_id = label_name_to_id # Use the model's labels
|
||||
else:
|
||||
@@ -431,15 +374,15 @@ def main():
|
||||
else:
|
||||
label_to_id = None
|
||||
# Now we've established our label2id, let's overwrite the model config with it.
|
||||
model.config.label2id = label_to_id
|
||||
if model.config.label2id is not None:
|
||||
model.config.id2label = {id: label for label, id in label_to_id.items()}
|
||||
config.label2id = label_to_id
|
||||
if config.label2id is not None:
|
||||
config.id2label = {id: label for label, id in label_to_id.items()}
|
||||
else:
|
||||
model.config.id2label = None
|
||||
config.id2label = None
|
||||
else:
|
||||
label_to_id = model.config.label2id # Just load the data from the model
|
||||
label_to_id = config.label2id # Just load the data from the model
|
||||
|
||||
if "validation" in datasets and model.config.label2id is not None:
|
||||
if "validation" in datasets and config.label2id is not None:
|
||||
validation_label_list = datasets["validation"].unique("label")
|
||||
for val_label in validation_label_list:
|
||||
assert val_label in label_to_id, f"Label {val_label} is in the validation set but not the training set!"
|
||||
@@ -449,87 +392,141 @@ def main():
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
||||
result = tokenizer(*args, max_length=max_seq_length, truncation=True)
|
||||
|
||||
# Map labels to IDs
|
||||
if model.config.label2id is not None and "label" in examples:
|
||||
result["label"] = [(model.config.label2id[l] if l != -1 else -1) for l in examples["label"]]
|
||||
if config.label2id is not None and "label" in examples:
|
||||
result["label"] = [(config.label2id[l] if l != -1 else -1) for l in examples["label"]]
|
||||
return result
|
||||
|
||||
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
||||
|
||||
if "train" in datasets:
|
||||
train_dataset = datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
# Log a few random samples from the training set so we can see that it's working as expected:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
if "validation" in datasets:
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
if "test" in datasets:
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Training
|
||||
if "train" in datasets:
|
||||
training_dataset = DataSequence(
|
||||
train_dataset, non_label_column_names, batch_size=training_args.per_device_train_batch_size, labels=True
|
||||
)
|
||||
if "validation" in datasets:
|
||||
eval_dataset = DataSequence(
|
||||
eval_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=True
|
||||
)
|
||||
with training_args.strategy.scope():
|
||||
# region Load pretrained model
|
||||
# Set seed before initializing model
|
||||
set_seed(training_args.seed)
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
if checkpoint is None:
|
||||
model_path = model_args.model_name_or_path
|
||||
else:
|
||||
eval_dataset = None
|
||||
model_path = checkpoint
|
||||
model = TFAutoModelForSequenceClassification.from_pretrained(
|
||||
model_path,
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# endregion
|
||||
|
||||
callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
|
||||
model.fit(
|
||||
training_dataset,
|
||||
validation_data=eval_dataset,
|
||||
epochs=int(training_args.num_train_epochs),
|
||||
callbacks=callbacks,
|
||||
# region Optimizer, loss and compilation
|
||||
optimizer = tf.keras.optimizers.Adam(
|
||||
learning_rate=training_args.learning_rate,
|
||||
beta_1=training_args.adam_beta1,
|
||||
beta_2=training_args.adam_beta2,
|
||||
epsilon=training_args.adam_epsilon,
|
||||
clipnorm=training_args.max_grad_norm,
|
||||
)
|
||||
elif "validation" in datasets:
|
||||
# If there's a validation dataset but no training set, just evaluate the metrics
|
||||
eval_dataset = DataSequence(
|
||||
eval_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=True
|
||||
)
|
||||
logger.info("Computing metrics on validation data...")
|
||||
if is_regression:
|
||||
loss = model.evaluate(eval_dataset)
|
||||
logger.info(f"Loss: {loss:.5f}")
|
||||
loss_fn = tf.keras.losses.MeanSquaredError()
|
||||
metrics = []
|
||||
else:
|
||||
loss, accuracy = model.evaluate(eval_dataset)
|
||||
logger.info(f"Loss: {loss:.5f}, Accuracy: {accuracy * 100:.4f}%")
|
||||
# endregion
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metrics = ["accuracy"]
|
||||
model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
|
||||
# endregion
|
||||
|
||||
# region Prediction
|
||||
if "test" in datasets:
|
||||
logger.info("Doing predictions on Predict dataset...")
|
||||
# region Convert data to TF format
|
||||
|
||||
predict_dataset = DataSequence(
|
||||
predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
||||
)
|
||||
predictions = model.predict(predict_dataset)["logits"]
|
||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
|
||||
with open(output_predict_file, "w") as writer:
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predictions):
|
||||
if is_regression:
|
||||
writer.write(f"{index}\t{item:3.3f}\n")
|
||||
else:
|
||||
item = model.config.id2label[item]
|
||||
writer.write(f"{index}\t{item}\n")
|
||||
logger.info(f"Wrote predictions to {output_predict_file}!")
|
||||
# Convert data to a tf.keras.utils.Sequence object for training if we're not using a TPU
|
||||
# For TPU, convert to a tf.data.Dataset
|
||||
tf_data = dict()
|
||||
max_samples = {
|
||||
"train": data_args.max_train_samples,
|
||||
"validation": data_args.max_val_samples,
|
||||
"test": data_args.max_test_samples,
|
||||
}
|
||||
for key in ("train", "validation", "test"):
|
||||
if key not in datasets:
|
||||
tf_data[key] = None
|
||||
continue
|
||||
if key in ("train", "validation"):
|
||||
assert "label" in datasets[key].features, f"Missing labels from {key} data!"
|
||||
if key == "train":
|
||||
shuffle = True
|
||||
batch_size = training_args.per_device_train_batch_size
|
||||
drop_remainder = True # Saves us worrying about scaling gradients for the last batch
|
||||
else:
|
||||
shuffle = False
|
||||
batch_size = training_args.per_device_eval_batch_size
|
||||
drop_remainder = False
|
||||
samples_limit = max_samples[key]
|
||||
dataset = datasets[key]
|
||||
if samples_limit is not None:
|
||||
dataset = dataset.select(range(samples_limit))
|
||||
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length:
|
||||
logger.info("Padding all batches to max length because argument was set or we're on TPU.")
|
||||
dataset_mode = "constant_batch"
|
||||
else:
|
||||
dataset_mode = "variable_batch"
|
||||
data = convert_dataset_for_tensorflow(
|
||||
dataset,
|
||||
non_label_column_names,
|
||||
batch_size=batch_size,
|
||||
dataset_mode=dataset_mode,
|
||||
drop_remainder=drop_remainder,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
tf_data[key] = data
|
||||
# endregion
|
||||
|
||||
# region Training and validation
|
||||
if tf_data["train"] is not None:
|
||||
callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
|
||||
model.fit(
|
||||
tf_data["train"],
|
||||
validation_data=tf_data["validation"],
|
||||
epochs=int(training_args.num_train_epochs),
|
||||
callbacks=callbacks,
|
||||
)
|
||||
elif tf_data["validation"] is not None:
|
||||
# If there's a validation dataset but no training set, just evaluate the metrics
|
||||
logger.info("Computing metrics on validation data...")
|
||||
if is_regression:
|
||||
loss = model.evaluate(tf_data["validation"])
|
||||
logger.info(f"Loss: {loss:.5f}")
|
||||
else:
|
||||
loss, accuracy = model.evaluate(tf_data["validation"])
|
||||
logger.info(f"Loss: {loss:.5f}, Accuracy: {accuracy * 100:.4f}%")
|
||||
# endregion
|
||||
|
||||
# region Prediction
|
||||
if tf_data["test"] is not None:
|
||||
logger.info("Doing predictions on test dataset...")
|
||||
predictions = model.predict(tf_data["test"])["logits"]
|
||||
predicted_class = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
|
||||
with open(output_test_file, "w") as writer:
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predicted_class):
|
||||
if is_regression:
|
||||
writer.write(f"{index}\t{item:3.3f}\n")
|
||||
else:
|
||||
item = config.id2label[item]
|
||||
writer.write(f"{index}\t{item}\n")
|
||||
logger.info(f"Wrote predictions to {output_test_file}!")
|
||||
# endregion
|
||||
|
||||
# region Prediction losses
|
||||
# This section is outside the scope() because it's very quick to compute, but behaves badly inside it
|
||||
if "label" in datasets["test"].features:
|
||||
print("Computing prediction loss on test labels...")
|
||||
labels = datasets["test"]["label"]
|
||||
loss = float(loss_fn(labels, predictions).numpy())
|
||||
print(f"Test loss: {loss:.4f}")
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
@@ -212,7 +212,10 @@ class TFTrainingArguments(TrainingArguments):
|
||||
else:
|
||||
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||||
except ValueError:
|
||||
tpu = None
|
||||
if self.tpu_name:
|
||||
raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!")
|
||||
else:
|
||||
tpu = None
|
||||
|
||||
if tpu:
|
||||
# Set to bfloat16 in case of TPU
|
||||
@@ -233,7 +236,7 @@ class TFTrainingArguments(TrainingArguments):
|
||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
else:
|
||||
raise ValueError("Cannot find the proper strategy please check your environment properties.")
|
||||
raise ValueError("Cannot find the proper strategy, please check your environment properties.")
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
Reference in New Issue
Block a user