[Re-submit] Compute true loss Flax examples (#19504)
* Compute true loss * fixup * final * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map * Compute true loss * final * fixup * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -335,7 +335,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
for idx in range(steps):
|
||||
|
||||
start_idx = batch_size * idx
|
||||
end_idx = batch_size * (idx + 1)
|
||||
|
||||
@@ -347,7 +346,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
||||
|
||||
|
||||
def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"):
|
||||
|
||||
if train_time:
|
||||
summary_writer.scalar("train_time", train_time, step)
|
||||
|
||||
@@ -782,11 +780,9 @@ def main():
|
||||
num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
|
||||
|
||||
for idx in range(num_splits):
|
||||
|
||||
if not block_size:
|
||||
_ds = ds
|
||||
else:
|
||||
|
||||
start_idx = block_size * idx
|
||||
end_idx = block_size * (idx + 1)
|
||||
|
||||
@@ -926,8 +922,9 @@ def main():
|
||||
|
||||
# ignore padded tokens from loss
|
||||
loss = loss * padding_mask
|
||||
loss = loss.sum() / padding_mask.sum()
|
||||
return loss
|
||||
loss = loss.sum()
|
||||
num_labels = padding_mask.sum()
|
||||
return loss, num_labels
|
||||
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch, label_smoothing_factor=0.0):
|
||||
@@ -936,29 +933,38 @@ def main():
|
||||
def compute_loss(params):
|
||||
labels = batch.pop("labels")
|
||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
return loss
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
return loss, num_labels
|
||||
|
||||
grad_fn = jax.value_and_grad(compute_loss)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||
(loss, num_labels), grad = grad_fn(state.params)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
# true grad = total grad / total samples
|
||||
grad = jax.lax.psum(grad, "batch")
|
||||
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||
|
||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
|
||||
return new_state, metrics
|
||||
|
||||
# Define eval fn
|
||||
def eval_step(params, batch, label_smoothing_factor=0.0):
|
||||
labels = batch.pop("labels")
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
|
||||
# summarize metrics
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
metrics = {"loss": loss}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
# Define generation function
|
||||
@@ -1024,7 +1030,6 @@ def main():
|
||||
ckpt_dir: str = "",
|
||||
is_prediction=False,
|
||||
):
|
||||
|
||||
logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
|
||||
|
||||
metrics = []
|
||||
@@ -1103,12 +1108,10 @@ def main():
|
||||
logger.info(desc)
|
||||
|
||||
if jax.process_index() == 0:
|
||||
|
||||
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
||||
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
||||
|
||||
if metrics:
|
||||
|
||||
# Save metrics (only for the evaluation/prediction being done along with training)
|
||||
if has_tensorboard and training_args.do_train:
|
||||
write_metric(
|
||||
@@ -1143,7 +1146,6 @@ def main():
|
||||
input_rng = None
|
||||
|
||||
if training_args.do_train:
|
||||
|
||||
cur_step = 0
|
||||
train_time = 0
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
@@ -1166,7 +1168,6 @@ def main():
|
||||
|
||||
# train
|
||||
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
||||
|
||||
cur_step += 1
|
||||
batch = next(train_batches)
|
||||
batch_start = time.time()
|
||||
@@ -1177,7 +1178,6 @@ def main():
|
||||
|
||||
# log and save info
|
||||
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
|
||||
|
||||
_train_metric = unreplicate(train_metric)
|
||||
desc = (
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
|
||||
@@ -1217,7 +1217,6 @@ def main():
|
||||
|
||||
# log and save info
|
||||
if training_args.logging_steps <= 0:
|
||||
|
||||
logger.info(desc)
|
||||
|
||||
with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
|
||||
|
||||
@@ -351,7 +351,7 @@ The example script uses the 🤗 Datasets library. You can easily customize them
|
||||
To setup all relevant files for training, let's create a directory.
|
||||
|
||||
```bash
|
||||
mkdir ./norwegian-roberta-base
|
||||
mkdir ./norwegian-bart-base
|
||||
```
|
||||
|
||||
### Train tokenizer
|
||||
|
||||
@@ -799,19 +799,25 @@ def main():
|
||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||
|
||||
# take average
|
||||
loss = loss.sum() / label_mask.sum()
|
||||
loss = loss.sum()
|
||||
num_labels = label_mask.sum()
|
||||
|
||||
return loss
|
||||
return loss, num_labels
|
||||
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
||||
(loss, num_labels), grad = grad_fn(state.params)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
# true grad = total grad / total samples
|
||||
grad = jax.lax.psum(grad, "batch")
|
||||
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||
new_state = state.apply_gradients(grads=grad)
|
||||
|
||||
metrics = jax.lax.pmean(
|
||||
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
||||
)
|
||||
|
||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||
return new_state, metrics, new_dropout_rng
|
||||
|
||||
# Create parallel version of the train step
|
||||
@@ -888,7 +894,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -903,9 +909,9 @@ def main():
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
|
||||
eval_normalizer = eval_metrics.pop("normalizer")
|
||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
|
||||
# Update progress bar
|
||||
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
||||
@@ -917,7 +923,7 @@ def main():
|
||||
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
@@ -928,7 +934,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -943,9 +949,9 @@ def main():
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
|
||||
eval_normalizer = eval_metrics.pop("normalizer")
|
||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
|
||||
try:
|
||||
perplexity = math.exp(eval_metrics["loss"])
|
||||
|
||||
@@ -723,18 +723,25 @@ def main():
|
||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||
|
||||
# take average
|
||||
loss = loss.sum() / label_mask.sum()
|
||||
loss = loss.sum()
|
||||
num_labels = label_mask.sum()
|
||||
|
||||
return loss
|
||||
return loss, num_labels
|
||||
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
||||
(loss, num_labels), grad = grad_fn(state.params)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
# true grad = total grad / total samples
|
||||
grad = jax.lax.psum(grad, "batch")
|
||||
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||
new_state = state.apply_gradients(grads=grad)
|
||||
|
||||
metrics = jax.lax.pmean(
|
||||
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
||||
)
|
||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||
|
||||
return new_state, metrics, new_dropout_rng
|
||||
|
||||
|
||||
@@ -328,7 +328,6 @@ class FlaxDataCollatorForT5MLM:
|
||||
decoder_start_token_id: int
|
||||
|
||||
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
|
||||
|
||||
# convert list to dict and tensorize input
|
||||
batch = BatchEncoding(
|
||||
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
|
||||
@@ -397,7 +396,6 @@ class FlaxDataCollatorForT5MLM:
|
||||
return input_ids
|
||||
|
||||
def random_spans_noise_mask(self, length):
|
||||
|
||||
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
|
||||
|
||||
Noise mask consisting of random spans of noise tokens.
|
||||
|
||||
@@ -784,8 +784,9 @@ def main():
|
||||
|
||||
# ignore padded tokens from loss
|
||||
loss = loss * padding_mask
|
||||
loss = loss.sum() / padding_mask.sum()
|
||||
return loss
|
||||
loss = loss.sum()
|
||||
num_labels = padding_mask.sum()
|
||||
return loss, num_labels
|
||||
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch, label_smoothing_factor=0.0):
|
||||
@@ -794,29 +795,38 @@ def main():
|
||||
def compute_loss(params):
|
||||
labels = batch.pop("labels")
|
||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
return loss
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
return loss, num_labels
|
||||
|
||||
grad_fn = jax.value_and_grad(compute_loss)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||
(loss, num_labels), grad = grad_fn(state.params)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
# true grad = total grad / total samples
|
||||
grad = jax.lax.psum(grad, "batch")
|
||||
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||
|
||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
|
||||
return new_state, metrics
|
||||
|
||||
# Define eval fn
|
||||
def eval_step(params, batch, label_smoothing_factor=0.0):
|
||||
labels = batch.pop("labels")
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
|
||||
# summarize metrics
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
metrics = {"loss": loss}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
# Define generation function
|
||||
|
||||
Reference in New Issue
Block a user