From 4212bb0d60b63f81cc2ef44efec02b2203c65eed Mon Sep 17 00:00:00 2001 From: "Duong A. Nguyen" <38061659+duongna21@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:33:36 +0700 Subject: [PATCH] [Re-submit] Compute true loss Flax examples (#19504) * Compute true loss * fixup * final * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map * Compute true loss * final * fixup * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../run_image_captioning_flax.py | 45 +++++++++---------- examples/flax/language-modeling/README.md | 2 +- .../language-modeling/run_bart_dlm_flax.py | 38 +++++++++------- .../flax/language-modeling/run_mlm_flax.py | 23 ++++++---- .../flax/language-modeling/run_t5_mlm_flax.py | 2 - .../summarization/run_summarization_flax.py | 34 +++++++++----- 6 files changed, 82 insertions(+), 62 deletions(-) diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index 5b3fd187f0..493caa6c7b 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -335,7 +335,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf batch_idx = np.arange(len(dataset)) for idx in range(steps): - start_idx = batch_size * idx end_idx = batch_size * (idx + 1) @@ -347,7 +346,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"): - if train_time: summary_writer.scalar("train_time", train_time, step) @@ -782,11 +780,9 @@ def main(): num_splits = steps // steps_per_block + int(steps % steps_per_block > 0) for idx in range(num_splits): - if not block_size: _ds = ds else: - start_idx = block_size * idx end_idx = block_size * (idx + 1) @@ -926,8 +922,9 @@ def main(): # ignore padded tokens from loss loss = loss * padding_mask - loss = loss.sum() / padding_mask.sum() - return loss + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels # Define gradient update step fn def train_step(state, batch, label_smoothing_factor=0.0): @@ -936,29 +933,38 @@ def main(): def compute_loss(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] - loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) - return loss + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels - grad_fn = jax.value_and_grad(compute_loss) - loss, grad = grad_fn(state.params) - grad = jax.lax.pmean(grad, "batch") + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} - metrics = jax.lax.pmean(metrics, axis_name="batch") - return new_state, metrics # Define eval fn def eval_step(params, batch, label_smoothing_factor=0.0): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] - loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) - # summarize metrics + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + metrics = {"loss": loss} - metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics # Define generation function @@ -1024,7 +1030,6 @@ def main(): ckpt_dir: str = "", is_prediction=False, ): - logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***") metrics = [] @@ -1103,12 +1108,10 @@ def main(): logger.info(desc) if jax.process_index() == 0: - if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)): os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True) if metrics: - # Save metrics (only for the evaluation/prediction being done along with training) if has_tensorboard and training_args.do_train: write_metric( @@ -1143,7 +1146,6 @@ def main(): input_rng = None if training_args.do_train: - cur_step = 0 train_time = 0 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) @@ -1166,7 +1168,6 @@ def main(): # train for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)): - cur_step += 1 batch = next(train_batches) batch_start = time.time() @@ -1177,7 +1178,6 @@ def main(): # log and save info if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0: - _train_metric = unreplicate(train_metric) desc = ( f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |" @@ -1217,7 +1217,6 @@ def main(): # log and save info if training_args.logging_steps <= 0: - logger.info(desc) with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp: diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md index 5b83ed0654..75fd64eb08 100644 --- a/examples/flax/language-modeling/README.md +++ b/examples/flax/language-modeling/README.md @@ -351,7 +351,7 @@ The example script uses the 🤗 Datasets library. You can easily customize them To setup all relevant files for training, let's create a directory. ```bash -mkdir ./norwegian-roberta-base +mkdir ./norwegian-bart-base ``` ### Train tokenizer diff --git a/examples/flax/language-modeling/run_bart_dlm_flax.py b/examples/flax/language-modeling/run_bart_dlm_flax.py index 6396f4ced9..6872e59345 100644 --- a/examples/flax/language-modeling/run_bart_dlm_flax.py +++ b/examples/flax/language-modeling/run_bart_dlm_flax.py @@ -799,19 +799,25 @@ def main(): loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask # take average - loss = loss.sum() / label_mask.sum() + loss = loss.sum() + num_labels = label_mask.sum() - return loss + return loss, num_labels - grad_fn = jax.value_and_grad(loss_fn) - loss, grad = grad_fn(state.params) - grad = jax.lax.pmean(grad, "batch") + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) new_state = state.apply_gradients(grads=grad) - metrics = jax.lax.pmean( - {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" - ) - + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} return new_state, metrics, new_dropout_rng # Create parallel version of the train step @@ -888,7 +894,7 @@ def main(): num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): @@ -903,9 +909,9 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.sum, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics) eval_normalizer = eval_metrics.pop("normalizer") - eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics) # Update progress bar epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" @@ -917,7 +923,7 @@ def main(): if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: @@ -928,7 +934,7 @@ def main(): num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): @@ -943,9 +949,9 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics) eval_normalizer = eval_metrics.pop("normalizer") - eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics) try: perplexity = math.exp(eval_metrics["loss"]) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 5e1519bbd5..2383492aa4 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -723,18 +723,25 @@ def main(): loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask # take average - loss = loss.sum() / label_mask.sum() + loss = loss.sum() + num_labels = label_mask.sum() - return loss + return loss, num_labels - grad_fn = jax.value_and_grad(loss_fn) - loss, grad = grad_fn(state.params) - grad = jax.lax.pmean(grad, "batch") + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) new_state = state.apply_gradients(grads=grad) - metrics = jax.lax.pmean( - {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" - ) + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} return new_state, metrics, new_dropout_rng diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index fd988890d0..ceae49c6b1 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -328,7 +328,6 @@ class FlaxDataCollatorForT5MLM: decoder_start_token_id: int def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding: - # convert list to dict and tensorize input batch = BatchEncoding( {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()} @@ -397,7 +396,6 @@ class FlaxDataCollatorForT5MLM: return input_ids def random_spans_noise_mask(self, length): - """This function is copy of `random_spans_helper `__ . Noise mask consisting of random spans of noise tokens. diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index ed151b8bbe..fb3eb8d28c 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -784,8 +784,9 @@ def main(): # ignore padded tokens from loss loss = loss * padding_mask - loss = loss.sum() / padding_mask.sum() - return loss + loss = loss.sum() + num_labels = padding_mask.sum() + return loss, num_labels # Define gradient update step fn def train_step(state, batch, label_smoothing_factor=0.0): @@ -794,29 +795,38 @@ def main(): def compute_loss(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] - loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) - return loss + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss, num_labels - grad_fn = jax.value_and_grad(compute_loss) - loss, grad = grad_fn(state.params) - grad = jax.lax.pmean(grad, "batch") + grad_fn = jax.value_and_grad(compute_loss, has_aux=True) + (loss, num_labels), grad = grad_fn(state.params) + num_labels = jax.lax.psum(num_labels, "batch") + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + + # true grad = total grad / total samples + grad = jax.lax.psum(grad, "batch") + grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad) new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} - metrics = jax.lax.pmean(metrics, axis_name="batch") - return new_state, metrics # Define eval fn def eval_step(params, batch, label_smoothing_factor=0.0): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] - loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) - # summarize metrics + loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + num_labels = jax.lax.psum(num_labels, "batch") + + # true loss = total loss / total samples + loss = jax.lax.psum(loss, "batch") + loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss) + metrics = {"loss": loss} - metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics # Define generation function