[Re-submit] Compute true loss Flax examples (#19504)
* Compute true loss * fixup * final * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map * Compute true loss * final * fixup * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -335,7 +335,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
for idx in range(steps):
|
||||
|
||||
start_idx = batch_size * idx
|
||||
end_idx = batch_size * (idx + 1)
|
||||
|
||||
@@ -347,7 +346,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
||||
|
||||
|
||||
def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"):
|
||||
|
||||
if train_time:
|
||||
summary_writer.scalar("train_time", train_time, step)
|
||||
|
||||
@@ -782,11 +780,9 @@ def main():
|
||||
num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
|
||||
|
||||
for idx in range(num_splits):
|
||||
|
||||
if not block_size:
|
||||
_ds = ds
|
||||
else:
|
||||
|
||||
start_idx = block_size * idx
|
||||
end_idx = block_size * (idx + 1)
|
||||
|
||||
@@ -926,8 +922,9 @@ def main():
|
||||
|
||||
# ignore padded tokens from loss
|
||||
loss = loss * padding_mask
|
||||
loss = loss.sum() / padding_mask.sum()
|
||||
return loss
|
||||
loss = loss.sum()
|
||||
num_labels = padding_mask.sum()
|
||||
return loss, num_labels
|
||||
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch, label_smoothing_factor=0.0):
|
||||
@@ -936,29 +933,38 @@ def main():
|
||||
def compute_loss(params):
|
||||
labels = batch.pop("labels")
|
||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
return loss
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
return loss, num_labels
|
||||
|
||||
grad_fn = jax.value_and_grad(compute_loss)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||
(loss, num_labels), grad = grad_fn(state.params)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
# true grad = total grad / total samples
|
||||
grad = jax.lax.psum(grad, "batch")
|
||||
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||
|
||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
|
||||
return new_state, metrics
|
||||
|
||||
# Define eval fn
|
||||
def eval_step(params, batch, label_smoothing_factor=0.0):
|
||||
labels = batch.pop("labels")
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
|
||||
# summarize metrics
|
||||
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||
num_labels = jax.lax.psum(num_labels, "batch")
|
||||
|
||||
# true loss = total loss / total samples
|
||||
loss = jax.lax.psum(loss, "batch")
|
||||
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||
|
||||
metrics = {"loss": loss}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
# Define generation function
|
||||
@@ -1024,7 +1030,6 @@ def main():
|
||||
ckpt_dir: str = "",
|
||||
is_prediction=False,
|
||||
):
|
||||
|
||||
logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
|
||||
|
||||
metrics = []
|
||||
@@ -1103,12 +1108,10 @@ def main():
|
||||
logger.info(desc)
|
||||
|
||||
if jax.process_index() == 0:
|
||||
|
||||
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
||||
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
||||
|
||||
if metrics:
|
||||
|
||||
# Save metrics (only for the evaluation/prediction being done along with training)
|
||||
if has_tensorboard and training_args.do_train:
|
||||
write_metric(
|
||||
@@ -1143,7 +1146,6 @@ def main():
|
||||
input_rng = None
|
||||
|
||||
if training_args.do_train:
|
||||
|
||||
cur_step = 0
|
||||
train_time = 0
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
@@ -1166,7 +1168,6 @@ def main():
|
||||
|
||||
# train
|
||||
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
||||
|
||||
cur_step += 1
|
||||
batch = next(train_batches)
|
||||
batch_start = time.time()
|
||||
@@ -1177,7 +1178,6 @@ def main():
|
||||
|
||||
# log and save info
|
||||
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
|
||||
|
||||
_train_metric = unreplicate(train_metric)
|
||||
desc = (
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
|
||||
@@ -1217,7 +1217,6 @@ def main():
|
||||
|
||||
# log and save info
|
||||
if training_args.logging_steps <= 0:
|
||||
|
||||
logger.info(desc)
|
||||
|
||||
with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
|
||||
|
||||
Reference in New Issue
Block a user