Update TF text classification example (#11496)
Big refactor, fixes and multi-GPU/TPU support
This commit is contained in:
@@ -54,6 +54,20 @@ After training, the model will be saved to `--output_dir`. Once your model is tr
|
|||||||
by calling the script without a `--train_file` or `--validation_file`; simply pass it the output_dir containing
|
by calling the script without a `--train_file` or `--validation_file`; simply pass it the output_dir containing
|
||||||
the trained model and a `--test_file` and it will write its predictions to a text file for you.
|
the trained model and a `--test_file` and it will write its predictions to a text file for you.
|
||||||
|
|
||||||
|
### Multi-GPU and TPU usage
|
||||||
|
|
||||||
|
By default, the script uses a `MirroredStrategy` and will use multiple GPUs effectively if they are available. TPUs
|
||||||
|
can also be used by passing the name of the TPU resource with the `--tpu` argument.
|
||||||
|
|
||||||
|
### Memory usage and data loading
|
||||||
|
|
||||||
|
One thing to note is that all data is loaded into memory in this script. Most text classification datasets are small
|
||||||
|
enough that this is not an issue, but if you have a very large dataset you will need to modify the script to handle
|
||||||
|
data streaming. This is particularly challenging for TPUs, given the stricter requirements and the sheer volume of data
|
||||||
|
required to keep them fed. A full explanation of all the possible pitfalls is a bit beyond this example script and
|
||||||
|
README, but for more information you can see the 'Input Datasets' section of
|
||||||
|
[this document](https://www.tensorflow.org/guide/tpu).
|
||||||
|
|
||||||
### Example command
|
### Example command
|
||||||
```
|
```
|
||||||
python run_text_classification.py \
|
python run_text_classification.py \
|
||||||
|
|||||||
@@ -18,10 +18,8 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from math import ceil
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -34,7 +32,7 @@ from transformers import (
|
|||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TrainingArguments,
|
TFTrainingArguments,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME
|
from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME
|
||||||
@@ -48,65 +46,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# region Helper classes
|
# region Helper classes
|
||||||
class DataSequence(tf.keras.utils.Sequence):
|
|
||||||
# We use a Sequence object to load the data. Although it's completely possible to load your data as Numpy/TF arrays
|
|
||||||
# and pass those straight to the Model, this constrains you in a couple of ways. Most notably, it requires all
|
|
||||||
# the data to be padded to the length of the longest input example, and it also requires the whole dataset to be
|
|
||||||
# loaded into memory. If these aren't major problems for you, you can skip the sequence object in your own code!
|
|
||||||
def __init__(self, dataset, non_label_column_names, batch_size, labels, shuffle=True):
|
|
||||||
super().__init__()
|
|
||||||
# Retain all of the columns not present in the original data - these are the ones added by the tokenizer
|
|
||||||
self.data = {
|
|
||||||
key: dataset[key]
|
|
||||||
for key in dataset.features.keys()
|
|
||||||
if key not in non_label_column_names and key != "label"
|
|
||||||
}
|
|
||||||
data_lengths = {len(array) for array in self.data.values()}
|
|
||||||
assert len(data_lengths) == 1, "Dataset arrays differ in length!"
|
|
||||||
self.data_length = data_lengths.pop()
|
|
||||||
self.num_batches = ceil(self.data_length / batch_size)
|
|
||||||
if labels:
|
|
||||||
self.labels = np.array(dataset["label"])
|
|
||||||
assert len(self.labels) == self.data_length, "Labels not the same length as input arrays!"
|
|
||||||
else:
|
|
||||||
self.labels = None
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.shuffle = shuffle
|
|
||||||
if self.shuffle:
|
|
||||||
# Shuffle the data order
|
|
||||||
self.permutation = np.random.permutation(self.data_length)
|
|
||||||
else:
|
|
||||||
self.permutation = None
|
|
||||||
|
|
||||||
def on_epoch_end(self):
|
|
||||||
# If we're shuffling, reshuffle the data order after each epoch
|
|
||||||
if self.shuffle:
|
|
||||||
self.permutation = np.random.permutation(self.data_length)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
# Note that this yields a batch, not a single sample
|
|
||||||
batch_start = item * self.batch_size
|
|
||||||
batch_end = (item + 1) * self.batch_size
|
|
||||||
if self.shuffle:
|
|
||||||
data_indices = self.permutation[batch_start:batch_end]
|
|
||||||
else:
|
|
||||||
data_indices = np.arange(batch_start, batch_end)
|
|
||||||
# We want to pad the data as little as possible, so we only pad each batch
|
|
||||||
# to the maximum length within that batch. We do that by stacking the variable-
|
|
||||||
# length inputs into a ragged tensor and then densifying it.
|
|
||||||
batch_input = {
|
|
||||||
key: tf.ragged.constant([data[i] for i in data_indices]).to_tensor() for key, data in self.data.items()
|
|
||||||
}
|
|
||||||
if self.labels is None:
|
|
||||||
return batch_input
|
|
||||||
else:
|
|
||||||
batch_labels = self.labels[data_indices]
|
|
||||||
return batch_input, batch_labels
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_batches
|
|
||||||
|
|
||||||
|
|
||||||
class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
||||||
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
|
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
|
||||||
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
|
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
|
||||||
@@ -119,8 +58,50 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
|||||||
self.model.save_pretrained(self.output_dir)
|
self.model.save_pretrained(self.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dataset_for_tensorflow(
|
||||||
|
dataset, non_label_column_names, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
|
||||||
|
):
|
||||||
|
"""Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
|
||||||
|
to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
|
||||||
|
is most useful when training on TPU, as a new graph compilation is required for each sequence length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def densify_ragged_batch(features, label=None):
|
||||||
|
features = {
|
||||||
|
feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) for feature, ragged_tensor in features.items()
|
||||||
|
}
|
||||||
|
if label is None:
|
||||||
|
return features
|
||||||
|
else:
|
||||||
|
return features, label
|
||||||
|
|
||||||
|
feature_keys = list(set(dataset.features.keys()) - set(non_label_column_names + ["label"]))
|
||||||
|
if dataset_mode == "variable_batch":
|
||||||
|
batch_shape = {key: None for key in feature_keys}
|
||||||
|
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
|
||||||
|
elif dataset_mode == "constant_batch":
|
||||||
|
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
|
||||||
|
batch_shape = {
|
||||||
|
key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
|
||||||
|
for key, ragged_tensor in data.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown dataset mode!")
|
||||||
|
|
||||||
|
if "label" in dataset.features:
|
||||||
|
labels = tf.convert_to_tensor(np.array(dataset["label"]))
|
||||||
|
tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
|
||||||
|
else:
|
||||||
|
tf_dataset = tf.data.Dataset.from_tensor_slices(data)
|
||||||
|
if shuffle:
|
||||||
|
tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
|
||||||
|
tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch)
|
||||||
|
return tf_dataset
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
# region Command-line arguments
|
# region Command-line arguments
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataTrainingArguments:
|
class DataTrainingArguments:
|
||||||
@@ -155,6 +136,7 @@ class DataTrainingArguments:
|
|||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||||
|
"Data will always be padded when using TPUs."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: Optional[int] = field(
|
||||||
@@ -164,17 +146,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_eval_samples: Optional[int] = field(
|
max_val_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_predict_samples: Optional[int] = field(
|
max_test_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -223,6 +205,7 @@ class ModelArguments:
|
|||||||
"with private models)."
|
"with private models)."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
tpu: Optional[str] = field(default=None, metadata={"help": "Name of the TPU resource to use, if available"})
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
@@ -234,7 +217,7 @@ def main():
|
|||||||
# or by passing the --help flag to this script.
|
# or by passing the --help flag to this script.
|
||||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
|
|
||||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
|
||||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
# If we pass only one argument to the script and it's the path to a json file,
|
# If we pass only one argument to the script and it's the path to a json file,
|
||||||
# let's parse it to get our arguments.
|
# let's parse it to get our arguments.
|
||||||
@@ -322,12 +305,7 @@ def main():
|
|||||||
is_regression = None
|
is_regression = None
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Load pretrained model and tokenizer
|
# region Load model config and tokenizer
|
||||||
# Set seed before initializing model
|
|
||||||
set_seed(training_args.seed)
|
|
||||||
#
|
|
||||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
|
||||||
# download model & vocab.
|
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
config_path = training_args.output_dir
|
config_path = training_args.output_dir
|
||||||
elif model_args.config_name:
|
elif model_args.config_name:
|
||||||
@@ -355,34 +333,6 @@ def main():
|
|||||||
revision=model_args.model_revision,
|
revision=model_args.model_revision,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
if checkpoint is None:
|
|
||||||
model_path = model_args.model_name_or_path
|
|
||||||
else:
|
|
||||||
model_path = checkpoint
|
|
||||||
model = TFAutoModelForSequenceClassification.from_pretrained(
|
|
||||||
model_path,
|
|
||||||
config=config,
|
|
||||||
cache_dir=model_args.cache_dir,
|
|
||||||
revision=model_args.model_revision,
|
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
|
||||||
)
|
|
||||||
# endregion
|
|
||||||
|
|
||||||
# region Optimizer, loss and compilation
|
|
||||||
optimizer = tf.keras.optimizers.Adam(
|
|
||||||
learning_rate=training_args.learning_rate,
|
|
||||||
beta_1=training_args.adam_beta1,
|
|
||||||
beta_2=training_args.adam_beta2,
|
|
||||||
epsilon=training_args.adam_epsilon,
|
|
||||||
clipnorm=training_args.max_grad_norm,
|
|
||||||
)
|
|
||||||
if is_regression:
|
|
||||||
loss = tf.keras.losses.MeanSquaredError()
|
|
||||||
metrics = []
|
|
||||||
else:
|
|
||||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
|
||||||
metrics = ["accuracy"]
|
|
||||||
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Dataset preprocessing
|
# region Dataset preprocessing
|
||||||
@@ -399,13 +349,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
sentence1_key, sentence2_key = non_label_column_names[0], None
|
sentence1_key, sentence2_key = non_label_column_names[0], None
|
||||||
|
|
||||||
# Padding strategy
|
|
||||||
if data_args.pad_to_max_length:
|
|
||||||
padding = "max_length"
|
|
||||||
else:
|
|
||||||
# We will pad later, dynamically at batch creation, to the max sequence length in each batch
|
|
||||||
padding = False
|
|
||||||
|
|
||||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
@@ -415,8 +358,8 @@ def main():
|
|||||||
|
|
||||||
# Ensure that our labels match the model's, if it has some pre-specified
|
# Ensure that our labels match the model's, if it has some pre-specified
|
||||||
if "train" in datasets:
|
if "train" in datasets:
|
||||||
if not is_regression and model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
if not is_regression and config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||||
label_name_to_id = model.config.label2id
|
label_name_to_id = config.label2id
|
||||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
||||||
label_to_id = label_name_to_id # Use the model's labels
|
label_to_id = label_name_to_id # Use the model's labels
|
||||||
else:
|
else:
|
||||||
@@ -431,15 +374,15 @@ def main():
|
|||||||
else:
|
else:
|
||||||
label_to_id = None
|
label_to_id = None
|
||||||
# Now we've established our label2id, let's overwrite the model config with it.
|
# Now we've established our label2id, let's overwrite the model config with it.
|
||||||
model.config.label2id = label_to_id
|
config.label2id = label_to_id
|
||||||
if model.config.label2id is not None:
|
if config.label2id is not None:
|
||||||
model.config.id2label = {id: label for label, id in label_to_id.items()}
|
config.id2label = {id: label for label, id in label_to_id.items()}
|
||||||
else:
|
else:
|
||||||
model.config.id2label = None
|
config.id2label = None
|
||||||
else:
|
else:
|
||||||
label_to_id = model.config.label2id # Just load the data from the model
|
label_to_id = config.label2id # Just load the data from the model
|
||||||
|
|
||||||
if "validation" in datasets and model.config.label2id is not None:
|
if "validation" in datasets and config.label2id is not None:
|
||||||
validation_label_list = datasets["validation"].unique("label")
|
validation_label_list = datasets["validation"].unique("label")
|
||||||
for val_label in validation_label_list:
|
for val_label in validation_label_list:
|
||||||
assert val_label in label_to_id, f"Label {val_label} is in the validation set but not the training set!"
|
assert val_label in label_to_id, f"Label {val_label} is in the validation set but not the training set!"
|
||||||
@@ -449,87 +392,141 @@ def main():
|
|||||||
args = (
|
args = (
|
||||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||||
)
|
)
|
||||||
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
result = tokenizer(*args, max_length=max_seq_length, truncation=True)
|
||||||
|
|
||||||
# Map labels to IDs
|
# Map labels to IDs
|
||||||
if model.config.label2id is not None and "label" in examples:
|
if config.label2id is not None and "label" in examples:
|
||||||
result["label"] = [(model.config.label2id[l] if l != -1 else -1) for l in examples["label"]]
|
result["label"] = [(config.label2id[l] if l != -1 else -1) for l in examples["label"]]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
||||||
|
|
||||||
if "train" in datasets:
|
|
||||||
train_dataset = datasets["train"]
|
|
||||||
if data_args.max_train_samples is not None:
|
|
||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
|
||||||
# Log a few random samples from the training set so we can see that it's working as expected:
|
|
||||||
for index in random.sample(range(len(train_dataset)), 3):
|
|
||||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
|
||||||
|
|
||||||
if "validation" in datasets:
|
|
||||||
eval_dataset = datasets["validation"]
|
|
||||||
if data_args.max_eval_samples is not None:
|
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
|
||||||
|
|
||||||
if "test" in datasets:
|
|
||||||
predict_dataset = datasets["test"]
|
|
||||||
if data_args.max_predict_samples is not None:
|
|
||||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Training
|
with training_args.strategy.scope():
|
||||||
if "train" in datasets:
|
# region Load pretrained model
|
||||||
training_dataset = DataSequence(
|
# Set seed before initializing model
|
||||||
train_dataset, non_label_column_names, batch_size=training_args.per_device_train_batch_size, labels=True
|
set_seed(training_args.seed)
|
||||||
)
|
#
|
||||||
if "validation" in datasets:
|
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||||
eval_dataset = DataSequence(
|
# download model & vocab.
|
||||||
eval_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=True
|
if checkpoint is None:
|
||||||
)
|
model_path = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
eval_dataset = None
|
model_path = checkpoint
|
||||||
|
model = TFAutoModelForSequenceClassification.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
config=config,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
revision=model_args.model_revision,
|
||||||
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
)
|
||||||
|
# endregion
|
||||||
|
|
||||||
callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
|
# region Optimizer, loss and compilation
|
||||||
model.fit(
|
optimizer = tf.keras.optimizers.Adam(
|
||||||
training_dataset,
|
learning_rate=training_args.learning_rate,
|
||||||
validation_data=eval_dataset,
|
beta_1=training_args.adam_beta1,
|
||||||
epochs=int(training_args.num_train_epochs),
|
beta_2=training_args.adam_beta2,
|
||||||
callbacks=callbacks,
|
epsilon=training_args.adam_epsilon,
|
||||||
|
clipnorm=training_args.max_grad_norm,
|
||||||
)
|
)
|
||||||
elif "validation" in datasets:
|
|
||||||
# If there's a validation dataset but no training set, just evaluate the metrics
|
|
||||||
eval_dataset = DataSequence(
|
|
||||||
eval_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=True
|
|
||||||
)
|
|
||||||
logger.info("Computing metrics on validation data...")
|
|
||||||
if is_regression:
|
if is_regression:
|
||||||
loss = model.evaluate(eval_dataset)
|
loss_fn = tf.keras.losses.MeanSquaredError()
|
||||||
logger.info(f"Loss: {loss:.5f}")
|
metrics = []
|
||||||
else:
|
else:
|
||||||
loss, accuracy = model.evaluate(eval_dataset)
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
logger.info(f"Loss: {loss:.5f}, Accuracy: {accuracy * 100:.4f}%")
|
metrics = ["accuracy"]
|
||||||
# endregion
|
model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
|
||||||
|
# endregion
|
||||||
|
|
||||||
# region Prediction
|
# region Convert data to TF format
|
||||||
if "test" in datasets:
|
|
||||||
logger.info("Doing predictions on Predict dataset...")
|
|
||||||
|
|
||||||
predict_dataset = DataSequence(
|
# Convert data to a tf.keras.utils.Sequence object for training if we're not using a TPU
|
||||||
predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
# For TPU, convert to a tf.data.Dataset
|
||||||
)
|
tf_data = dict()
|
||||||
predictions = model.predict(predict_dataset)["logits"]
|
max_samples = {
|
||||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
"train": data_args.max_train_samples,
|
||||||
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
|
"validation": data_args.max_val_samples,
|
||||||
with open(output_predict_file, "w") as writer:
|
"test": data_args.max_test_samples,
|
||||||
writer.write("index\tprediction\n")
|
}
|
||||||
for index, item in enumerate(predictions):
|
for key in ("train", "validation", "test"):
|
||||||
if is_regression:
|
if key not in datasets:
|
||||||
writer.write(f"{index}\t{item:3.3f}\n")
|
tf_data[key] = None
|
||||||
else:
|
continue
|
||||||
item = model.config.id2label[item]
|
if key in ("train", "validation"):
|
||||||
writer.write(f"{index}\t{item}\n")
|
assert "label" in datasets[key].features, f"Missing labels from {key} data!"
|
||||||
logger.info(f"Wrote predictions to {output_predict_file}!")
|
if key == "train":
|
||||||
|
shuffle = True
|
||||||
|
batch_size = training_args.per_device_train_batch_size
|
||||||
|
drop_remainder = True # Saves us worrying about scaling gradients for the last batch
|
||||||
|
else:
|
||||||
|
shuffle = False
|
||||||
|
batch_size = training_args.per_device_eval_batch_size
|
||||||
|
drop_remainder = False
|
||||||
|
samples_limit = max_samples[key]
|
||||||
|
dataset = datasets[key]
|
||||||
|
if samples_limit is not None:
|
||||||
|
dataset = dataset.select(range(samples_limit))
|
||||||
|
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length:
|
||||||
|
logger.info("Padding all batches to max length because argument was set or we're on TPU.")
|
||||||
|
dataset_mode = "constant_batch"
|
||||||
|
else:
|
||||||
|
dataset_mode = "variable_batch"
|
||||||
|
data = convert_dataset_for_tensorflow(
|
||||||
|
dataset,
|
||||||
|
non_label_column_names,
|
||||||
|
batch_size=batch_size,
|
||||||
|
dataset_mode=dataset_mode,
|
||||||
|
drop_remainder=drop_remainder,
|
||||||
|
shuffle=shuffle,
|
||||||
|
)
|
||||||
|
tf_data[key] = data
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Training and validation
|
||||||
|
if tf_data["train"] is not None:
|
||||||
|
callbacks = [SavePretrainedCallback(output_dir=training_args.output_dir)]
|
||||||
|
model.fit(
|
||||||
|
tf_data["train"],
|
||||||
|
validation_data=tf_data["validation"],
|
||||||
|
epochs=int(training_args.num_train_epochs),
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
elif tf_data["validation"] is not None:
|
||||||
|
# If there's a validation dataset but no training set, just evaluate the metrics
|
||||||
|
logger.info("Computing metrics on validation data...")
|
||||||
|
if is_regression:
|
||||||
|
loss = model.evaluate(tf_data["validation"])
|
||||||
|
logger.info(f"Loss: {loss:.5f}")
|
||||||
|
else:
|
||||||
|
loss, accuracy = model.evaluate(tf_data["validation"])
|
||||||
|
logger.info(f"Loss: {loss:.5f}, Accuracy: {accuracy * 100:.4f}%")
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Prediction
|
||||||
|
if tf_data["test"] is not None:
|
||||||
|
logger.info("Doing predictions on test dataset...")
|
||||||
|
predictions = model.predict(tf_data["test"])["logits"]
|
||||||
|
predicted_class = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||||
|
output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
|
||||||
|
with open(output_test_file, "w") as writer:
|
||||||
|
writer.write("index\tprediction\n")
|
||||||
|
for index, item in enumerate(predicted_class):
|
||||||
|
if is_regression:
|
||||||
|
writer.write(f"{index}\t{item:3.3f}\n")
|
||||||
|
else:
|
||||||
|
item = config.id2label[item]
|
||||||
|
writer.write(f"{index}\t{item}\n")
|
||||||
|
logger.info(f"Wrote predictions to {output_test_file}!")
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Prediction losses
|
||||||
|
# This section is outside the scope() because it's very quick to compute, but behaves badly inside it
|
||||||
|
if "label" in datasets["test"].features:
|
||||||
|
print("Computing prediction loss on test labels...")
|
||||||
|
labels = datasets["test"]["label"]
|
||||||
|
loss = float(loss_fn(labels, predictions).numpy())
|
||||||
|
print(f"Test loss: {loss:.4f}")
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -212,7 +212,10 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
else:
|
else:
|
||||||
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
tpu = None
|
if self.tpu_name:
|
||||||
|
raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!")
|
||||||
|
else:
|
||||||
|
tpu = None
|
||||||
|
|
||||||
if tpu:
|
if tpu:
|
||||||
# Set to bfloat16 in case of TPU
|
# Set to bfloat16 in case of TPU
|
||||||
@@ -233,7 +236,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||||
strategy = tf.distribute.MirroredStrategy()
|
strategy = tf.distribute.MirroredStrategy()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot find the proper strategy please check your environment properties.")
|
raise ValueError("Cannot find the proper strategy, please check your environment properties.")
|
||||||
|
|
||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user