Update TF LM examples (#15855)
This commit is contained in:
@@ -29,13 +29,11 @@ import os
|
||||
import random
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from datasets import load_dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
@@ -48,6 +46,7 @@ from transformers import (
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
DefaultDataCollator,
|
||||
HfArgumentParser,
|
||||
TFAutoModelForCausalLM,
|
||||
TFTrainingArguments,
|
||||
@@ -160,9 +159,6 @@ class DataTrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
mlm_probability: float = field(
|
||||
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
||||
)
|
||||
line_by_line: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
||||
@@ -212,20 +208,6 @@ class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
||||
self.model.save_pretrained(self.output_dir)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Data generator
|
||||
def sample_generator(dataset, tokenizer):
    """Yield every example of ``dataset`` exactly once, in a random order.

    Nothing is trimmed or padded here; dropping a partial final batch is
    left to whatever batches the generator's output.

    Args:
        dataset: Object supporting ``len()`` and integer indexing. Each item
            must be a dict that contains a ``"labels"`` entry.
        tokenizer: Unused. Kept so that existing callers which pass it
            positionally keep working.

    Yields:
        ``(features, labels)`` tuples, where ``features`` maps every key of
        the example to a tensor and ``labels`` is ``features["labels"]``.
    """
    # A fresh permutation is drawn every time the generator is started.
    for sample_idx in np.random.permutation(len(dataset)):
        # Index with a plain Python int rather than a NumPy integer
        # (presumably what the dataset's __getitem__ expects -- TODO confirm).
        example = dataset[int(sample_idx)]
        # Convert each field to a tensor, hinting int64 as the dtype.
        example = {key: tf.convert_to_tensor(arr, dtype_hint=tf.int64) for key, arr in example.items()}
        yield example, example["labels"]  # TF needs some kind of labels, even if we don't use them
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -457,34 +439,27 @@ def main():
|
||||
|
||||
# region TF Dataset preparation
|
||||
num_replicas = training_args.strategy.num_replicas_in_sync
|
||||
train_generator = partial(sample_generator, train_dataset, tokenizer)
|
||||
train_signature = {
|
||||
feature: tf.TensorSpec(shape=(None,), dtype=tf.int64)
|
||||
for feature in train_dataset.features
|
||||
if feature != "special_tokens_mask"
|
||||
}
|
||||
train_sig = (train_signature, train_signature["labels"])
|
||||
data_collator = DefaultDataCollator(return_tensors="tf")
|
||||
options = tf.data.Options()
|
||||
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
|
||||
tf_train_dataset = (
|
||||
tf.data.Dataset.from_generator(train_generator, output_signature=train_sig)
|
||||
.with_options(options)
|
||||
.batch(batch_size=num_replicas * training_args.per_device_train_batch_size, drop_remainder=True)
|
||||
.repeat(int(training_args.num_train_epochs))
|
||||
)
|
||||
eval_generator = partial(sample_generator, eval_dataset, tokenizer)
|
||||
eval_signature = {
|
||||
feature: tf.TensorSpec(shape=(None,), dtype=tf.int64)
|
||||
for feature in eval_dataset.features
|
||||
if feature != "special_tokens_mask"
|
||||
}
|
||||
eval_sig = (eval_signature, eval_signature["labels"])
|
||||
tf_eval_dataset = (
|
||||
tf.data.Dataset.from_generator(eval_generator, output_signature=eval_sig)
|
||||
.with_options(options)
|
||||
.batch(batch_size=num_replicas * training_args.per_device_eval_batch_size, drop_remainder=True)
|
||||
.repeat(int(training_args.num_train_epochs))
|
||||
)
|
||||
|
||||
tf_train_dataset = train_dataset.to_tf_dataset(
|
||||
# labels are passed as input, as we will use the model's internal loss
|
||||
columns=[col for col in train_dataset.features if col != "special_tokens_mask"],
|
||||
shuffle=True,
|
||||
batch_size=num_replicas * training_args.per_device_train_batch_size,
|
||||
collate_fn=data_collator,
|
||||
drop_remainder=True,
|
||||
).with_options(options)
|
||||
|
||||
# Evaluation pipeline: unshuffled, batched across all replicas.
tf_eval_dataset = eval_dataset.to_tf_dataset(
    # labels are passed as input, as we will use the model's internal loss
    columns=[col for col in eval_dataset.features if col != "special_tokens_mask"],
    shuffle=False,
    # Size evaluation batches from the *eval* setting, not the training one,
    # so that per_device_eval_batch_size is actually honoured.
    batch_size=num_replicas * training_args.per_device_eval_batch_size,
    collate_fn=data_collator,
    drop_remainder=True,
).with_options(options)
|
||||
# endregion
|
||||
|
||||
# region Optimizer and loss
|
||||
@@ -500,10 +475,8 @@ def main():
|
||||
weight_decay_rate=training_args.weight_decay,
|
||||
)
|
||||
|
||||
def dummy_loss(y_true, y_pred):
|
||||
return tf.reduce_mean(y_pred)
|
||||
|
||||
model.compile(optimizer=optimizer, loss={"loss": dummy_loss})
|
||||
# no user-specified loss = will use the model internal loss
|
||||
model.compile(optimizer=optimizer)
|
||||
# endregion
|
||||
|
||||
# region Training and validation
|
||||
|
||||
Reference in New Issue
Block a user