[Re-submit] Compute true loss Flax examples (#19504)
* Compute true loss * fixup * final * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map * Compute true loss * final * fixup * final * final * Update examples/flax/language-modeling/run_bart_dlm_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * jax.tree_map => jax.tree_util.tree_map Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -335,7 +335,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|||||||
batch_idx = np.arange(len(dataset))
|
batch_idx = np.arange(len(dataset))
|
||||||
|
|
||||||
for idx in range(steps):
|
for idx in range(steps):
|
||||||
|
|
||||||
start_idx = batch_size * idx
|
start_idx = batch_size * idx
|
||||||
end_idx = batch_size * (idx + 1)
|
end_idx = batch_size * (idx + 1)
|
||||||
|
|
||||||
@@ -347,7 +346,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|||||||
|
|
||||||
|
|
||||||
def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"):
|
def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"):
|
||||||
|
|
||||||
if train_time:
|
if train_time:
|
||||||
summary_writer.scalar("train_time", train_time, step)
|
summary_writer.scalar("train_time", train_time, step)
|
||||||
|
|
||||||
@@ -782,11 +780,9 @@ def main():
|
|||||||
num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
|
num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
|
||||||
|
|
||||||
for idx in range(num_splits):
|
for idx in range(num_splits):
|
||||||
|
|
||||||
if not block_size:
|
if not block_size:
|
||||||
_ds = ds
|
_ds = ds
|
||||||
else:
|
else:
|
||||||
|
|
||||||
start_idx = block_size * idx
|
start_idx = block_size * idx
|
||||||
end_idx = block_size * (idx + 1)
|
end_idx = block_size * (idx + 1)
|
||||||
|
|
||||||
@@ -926,8 +922,9 @@ def main():
|
|||||||
|
|
||||||
# ignore padded tokens from loss
|
# ignore padded tokens from loss
|
||||||
loss = loss * padding_mask
|
loss = loss * padding_mask
|
||||||
loss = loss.sum() / padding_mask.sum()
|
loss = loss.sum()
|
||||||
return loss
|
num_labels = padding_mask.sum()
|
||||||
|
return loss, num_labels
|
||||||
|
|
||||||
# Define gradient update step fn
|
# Define gradient update step fn
|
||||||
def train_step(state, batch, label_smoothing_factor=0.0):
|
def train_step(state, batch, label_smoothing_factor=0.0):
|
||||||
@@ -936,29 +933,38 @@ def main():
|
|||||||
def compute_loss(params):
|
def compute_loss(params):
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||||
return loss
|
return loss, num_labels
|
||||||
|
|
||||||
grad_fn = jax.value_and_grad(compute_loss)
|
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||||
loss, grad = grad_fn(state.params)
|
(loss, num_labels), grad = grad_fn(state.params)
|
||||||
grad = jax.lax.pmean(grad, "batch")
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
# true loss = total loss / total samples
|
||||||
|
loss = jax.lax.psum(loss, "batch")
|
||||||
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
|
# true grad = total grad / total samples
|
||||||
|
grad = jax.lax.psum(grad, "batch")
|
||||||
|
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||||
|
|
||||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
|
||||||
|
|
||||||
return new_state, metrics
|
return new_state, metrics
|
||||||
|
|
||||||
# Define eval fn
|
# Define eval fn
|
||||||
def eval_step(params, batch, label_smoothing_factor=0.0):
|
def eval_step(params, batch, label_smoothing_factor=0.0):
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
logits = model(**batch, params=params, train=False)[0]
|
logits = model(**batch, params=params, train=False)[0]
|
||||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
|
||||||
|
|
||||||
# summarize metrics
|
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||||
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
# true loss = total loss / total samples
|
||||||
|
loss = jax.lax.psum(loss, "batch")
|
||||||
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
metrics = {"loss": loss}
|
metrics = {"loss": loss}
|
||||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
# Define generation function
|
# Define generation function
|
||||||
@@ -1024,7 +1030,6 @@ def main():
|
|||||||
ckpt_dir: str = "",
|
ckpt_dir: str = "",
|
||||||
is_prediction=False,
|
is_prediction=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
|
logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
|
||||||
|
|
||||||
metrics = []
|
metrics = []
|
||||||
@@ -1103,12 +1108,10 @@ def main():
|
|||||||
logger.info(desc)
|
logger.info(desc)
|
||||||
|
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
|
|
||||||
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
|
||||||
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
|
||||||
|
|
||||||
if metrics:
|
if metrics:
|
||||||
|
|
||||||
# Save metrics (only for the evaluation/prediction being done along with training)
|
# Save metrics (only for the evaluation/prediction being done along with training)
|
||||||
if has_tensorboard and training_args.do_train:
|
if has_tensorboard and training_args.do_train:
|
||||||
write_metric(
|
write_metric(
|
||||||
@@ -1143,7 +1146,6 @@ def main():
|
|||||||
input_rng = None
|
input_rng = None
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
|
|
||||||
cur_step = 0
|
cur_step = 0
|
||||||
train_time = 0
|
train_time = 0
|
||||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||||
@@ -1166,7 +1168,6 @@ def main():
|
|||||||
|
|
||||||
# train
|
# train
|
||||||
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
|
||||||
|
|
||||||
cur_step += 1
|
cur_step += 1
|
||||||
batch = next(train_batches)
|
batch = next(train_batches)
|
||||||
batch_start = time.time()
|
batch_start = time.time()
|
||||||
@@ -1177,7 +1178,6 @@ def main():
|
|||||||
|
|
||||||
# log and save info
|
# log and save info
|
||||||
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
|
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
|
||||||
|
|
||||||
_train_metric = unreplicate(train_metric)
|
_train_metric = unreplicate(train_metric)
|
||||||
desc = (
|
desc = (
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
|
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
|
||||||
@@ -1217,7 +1217,6 @@ def main():
|
|||||||
|
|
||||||
# log and save info
|
# log and save info
|
||||||
if training_args.logging_steps <= 0:
|
if training_args.logging_steps <= 0:
|
||||||
|
|
||||||
logger.info(desc)
|
logger.info(desc)
|
||||||
|
|
||||||
with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
|
with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
|
||||||
|
|||||||
@@ -351,7 +351,7 @@ The example script uses the 🤗 Datasets library. You can easily customize them
|
|||||||
To setup all relevant files for training, let's create a directory.
|
To setup all relevant files for training, let's create a directory.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
mkdir ./norwegian-roberta-base
|
mkdir ./norwegian-bart-base
|
||||||
```
|
```
|
||||||
|
|
||||||
### Train tokenizer
|
### Train tokenizer
|
||||||
|
|||||||
@@ -799,19 +799,25 @@ def main():
|
|||||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||||
|
|
||||||
# take average
|
# take average
|
||||||
loss = loss.sum() / label_mask.sum()
|
loss = loss.sum()
|
||||||
|
num_labels = label_mask.sum()
|
||||||
|
|
||||||
return loss
|
return loss, num_labels
|
||||||
|
|
||||||
grad_fn = jax.value_and_grad(loss_fn)
|
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
||||||
loss, grad = grad_fn(state.params)
|
(loss, num_labels), grad = grad_fn(state.params)
|
||||||
grad = jax.lax.pmean(grad, "batch")
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
# true loss = total loss / total samples
|
||||||
|
loss = jax.lax.psum(loss, "batch")
|
||||||
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
|
# true grad = total grad / total samples
|
||||||
|
grad = jax.lax.psum(grad, "batch")
|
||||||
|
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||||
new_state = state.apply_gradients(grads=grad)
|
new_state = state.apply_gradients(grads=grad)
|
||||||
|
|
||||||
metrics = jax.lax.pmean(
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_state, metrics, new_dropout_rng
|
return new_state, metrics, new_dropout_rng
|
||||||
|
|
||||||
# Create parallel version of the train step
|
# Create parallel version of the train step
|
||||||
@@ -888,7 +894,7 @@ def main():
|
|||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
# Avoid using jax.numpy here in case of TPU training
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
eval_samples_idx = np.arange(num_eval_samples)
|
eval_samples_idx = np.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||||
@@ -903,9 +909,9 @@ def main():
|
|||||||
|
|
||||||
# normalize eval metrics
|
# normalize eval metrics
|
||||||
eval_metrics = get_metrics(eval_metrics)
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
|
||||||
eval_normalizer = eval_metrics.pop("normalizer")
|
eval_normalizer = eval_metrics.pop("normalizer")
|
||||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||||
|
|
||||||
# Update progress bar
|
# Update progress bar
|
||||||
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
||||||
@@ -917,7 +923,7 @@ def main():
|
|||||||
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(training_args.output_dir, params=params)
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
tokenizer.save_pretrained(training_args.output_dir)
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
@@ -928,7 +934,7 @@ def main():
|
|||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
# Avoid using jax.numpy here in case of TPU training
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
eval_samples_idx = np.arange(num_eval_samples)
|
eval_samples_idx = np.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||||
@@ -943,9 +949,9 @@ def main():
|
|||||||
|
|
||||||
# normalize eval metrics
|
# normalize eval metrics
|
||||||
eval_metrics = get_metrics(eval_metrics)
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
|
eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
|
||||||
eval_normalizer = eval_metrics.pop("normalizer")
|
eval_normalizer = eval_metrics.pop("normalizer")
|
||||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
perplexity = math.exp(eval_metrics["loss"])
|
perplexity = math.exp(eval_metrics["loss"])
|
||||||
|
|||||||
@@ -723,18 +723,25 @@ def main():
|
|||||||
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
||||||
|
|
||||||
# take average
|
# take average
|
||||||
loss = loss.sum() / label_mask.sum()
|
loss = loss.sum()
|
||||||
|
num_labels = label_mask.sum()
|
||||||
|
|
||||||
return loss
|
return loss, num_labels
|
||||||
|
|
||||||
grad_fn = jax.value_and_grad(loss_fn)
|
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
||||||
loss, grad = grad_fn(state.params)
|
(loss, num_labels), grad = grad_fn(state.params)
|
||||||
grad = jax.lax.pmean(grad, "batch")
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
# true loss = total loss / total samples
|
||||||
|
loss = jax.lax.psum(loss, "batch")
|
||||||
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
|
# true grad = total grad / total samples
|
||||||
|
grad = jax.lax.psum(grad, "batch")
|
||||||
|
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||||
new_state = state.apply_gradients(grads=grad)
|
new_state = state.apply_gradients(grads=grad)
|
||||||
|
|
||||||
metrics = jax.lax.pmean(
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_state, metrics, new_dropout_rng
|
return new_state, metrics, new_dropout_rng
|
||||||
|
|
||||||
|
|||||||
@@ -328,7 +328,6 @@ class FlaxDataCollatorForT5MLM:
|
|||||||
decoder_start_token_id: int
|
decoder_start_token_id: int
|
||||||
|
|
||||||
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
|
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
|
||||||
|
|
||||||
# convert list to dict and tensorize input
|
# convert list to dict and tensorize input
|
||||||
batch = BatchEncoding(
|
batch = BatchEncoding(
|
||||||
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
|
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
|
||||||
@@ -397,7 +396,6 @@ class FlaxDataCollatorForT5MLM:
|
|||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def random_spans_noise_mask(self, length):
|
def random_spans_noise_mask(self, length):
|
||||||
|
|
||||||
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
|
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
|
||||||
|
|
||||||
Noise mask consisting of random spans of noise tokens.
|
Noise mask consisting of random spans of noise tokens.
|
||||||
|
|||||||
@@ -784,8 +784,9 @@ def main():
|
|||||||
|
|
||||||
# ignore padded tokens from loss
|
# ignore padded tokens from loss
|
||||||
loss = loss * padding_mask
|
loss = loss * padding_mask
|
||||||
loss = loss.sum() / padding_mask.sum()
|
loss = loss.sum()
|
||||||
return loss
|
num_labels = padding_mask.sum()
|
||||||
|
return loss, num_labels
|
||||||
|
|
||||||
# Define gradient update step fn
|
# Define gradient update step fn
|
||||||
def train_step(state, batch, label_smoothing_factor=0.0):
|
def train_step(state, batch, label_smoothing_factor=0.0):
|
||||||
@@ -794,29 +795,38 @@ def main():
|
|||||||
def compute_loss(params):
|
def compute_loss(params):
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||||
return loss
|
return loss, num_labels
|
||||||
|
|
||||||
grad_fn = jax.value_and_grad(compute_loss)
|
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
||||||
loss, grad = grad_fn(state.params)
|
(loss, num_labels), grad = grad_fn(state.params)
|
||||||
grad = jax.lax.pmean(grad, "batch")
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
# true loss = total loss / total samples
|
||||||
|
loss = jax.lax.psum(loss, "batch")
|
||||||
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
|
# true grad = total grad / total samples
|
||||||
|
grad = jax.lax.psum(grad, "batch")
|
||||||
|
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
||||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||||
|
|
||||||
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
||||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
|
||||||
|
|
||||||
return new_state, metrics
|
return new_state, metrics
|
||||||
|
|
||||||
# Define eval fn
|
# Define eval fn
|
||||||
def eval_step(params, batch, label_smoothing_factor=0.0):
|
def eval_step(params, batch, label_smoothing_factor=0.0):
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
logits = model(**batch, params=params, train=False)[0]
|
logits = model(**batch, params=params, train=False)[0]
|
||||||
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
|
||||||
|
|
||||||
# summarize metrics
|
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
||||||
|
num_labels = jax.lax.psum(num_labels, "batch")
|
||||||
|
|
||||||
|
# true loss = total loss / total samples
|
||||||
|
loss = jax.lax.psum(loss, "batch")
|
||||||
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
||||||
|
|
||||||
metrics = {"loss": loss}
|
metrics = {"loss": loss}
|
||||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
# Define generation function
|
# Define generation function
|
||||||
|
|||||||
Reference in New Issue
Block a user