TF text classification examples (#15704)
* Working example with to_tf_dataset * updated text_classification * more comments
This commit is contained in:
@@ -29,6 +29,8 @@ from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
DefaultDataCollator,
|
||||
HfArgumentParser,
|
||||
PretrainedConfig,
|
||||
TFAutoModelForSequenceClassification,
|
||||
@@ -58,47 +60,6 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
||||
self.model.save_pretrained(self.output_dir)
|
||||
|
||||
|
||||
def convert_dataset_for_tensorflow(
|
||||
dataset, non_label_column_names, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
|
||||
):
|
||||
"""Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
|
||||
to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
|
||||
is most useful when training on TPU, as a new graph compilation is required for each sequence length.
|
||||
"""
|
||||
|
||||
def densify_ragged_batch(features, label=None):
|
||||
features = {
|
||||
feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) for feature, ragged_tensor in features.items()
|
||||
}
|
||||
if label is None:
|
||||
return features
|
||||
else:
|
||||
return features, label
|
||||
|
||||
feature_keys = list(set(dataset.features.keys()) - set(non_label_column_names + ["label"]))
|
||||
if dataset_mode == "variable_batch":
|
||||
batch_shape = {key: None for key in feature_keys}
|
||||
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
|
||||
elif dataset_mode == "constant_batch":
|
||||
data = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
|
||||
batch_shape = {
|
||||
key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
|
||||
for key, ragged_tensor in data.items()
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unknown dataset mode!")
|
||||
|
||||
if "label" in dataset.features:
|
||||
labels = tf.convert_to_tensor(np.array(dataset["label"]))
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
|
||||
else:
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices(data)
|
||||
if shuffle:
|
||||
tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
|
||||
tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch)
|
||||
return tf_dataset
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -399,6 +360,11 @@ def main():
|
||||
return result
|
||||
|
||||
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
||||
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = DefaultDataCollator(return_tensors="tf")
|
||||
else:
|
||||
data_collator = DataCollatorWithPadding(tokenizer, return_tensors="tf")
|
||||
# endregion
|
||||
|
||||
with training_args.strategy.scope():
|
||||
@@ -464,18 +430,14 @@ def main():
|
||||
dataset = datasets[key]
|
||||
if samples_limit is not None:
|
||||
dataset = dataset.select(range(samples_limit))
|
||||
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length:
|
||||
logger.info("Padding all batches to max length because argument was set or we're on TPU.")
|
||||
dataset_mode = "constant_batch"
|
||||
else:
|
||||
dataset_mode = "variable_batch"
|
||||
data = convert_dataset_for_tensorflow(
|
||||
dataset,
|
||||
non_label_column_names,
|
||||
batch_size=batch_size,
|
||||
dataset_mode=dataset_mode,
|
||||
drop_remainder=drop_remainder,
|
||||
data = dataset.to_tf_dataset(
|
||||
columns=[col for col in dataset.column_names if col not in set(non_label_column_names + ["label"])],
|
||||
shuffle=shuffle,
|
||||
batch_size=batch_size,
|
||||
collate_fn=data_collator,
|
||||
drop_remainder=drop_remainder,
|
||||
# `label_cols` is needed for user-defined losses, such as in this example
|
||||
label_cols="label" if "label" in dataset.column_names else None,
|
||||
)
|
||||
tf_data[key] = data
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user