[Flax] Adapt examples to be able to use eval_steps and save_steps (#12543)

* fix_torch_device_generate_test

* remove @

* up

* up

* correct

* upload

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-07-06 19:41:51 +01:00
committed by GitHub
parent 2870fd198f
commit 208df208bf
4 changed files with 107 additions and 95 deletions

View File

@@ -141,6 +141,8 @@ Next we can run the example script to pretrain the model:
--adam_beta1="0.9" \ --adam_beta1="0.9" \
--adam_beta2="0.98" \ --adam_beta2="0.98" \
--logging_steps="500" \ --logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub --push_to_hub
``` ```
@@ -234,6 +236,8 @@ Next we can run the example script to pretrain the model:
--overwrite_output_dir \ --overwrite_output_dir \
--num_train_epochs="20" \ --num_train_epochs="20" \
--logging_steps="500" \ --logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub --push_to_hub
``` ```
@@ -370,6 +374,8 @@ Next we can run the example script to pretrain the model:
--overwrite_output_dir \ --overwrite_output_dir \
--num_train_epochs="10" \ --num_train_epochs="10" \
--logging_steps="500" \ --logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub --push_to_hub
``` ```

View File

@@ -587,45 +587,46 @@ def main():
train_metrics = [] train_metrics = []
# ======================== Evaluating ============================== if cur_step % training_args.eval_steps == 0 and cur_step > 0:
eval_metrics = [] # ======================== Evaluating ==============================
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) eval_metrics = []
eval_steps = len(eval_dataset) // eval_batch_size eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): eval_steps = len(eval_dataset) // eval_batch_size
# Model forward for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
batch = next(eval_loader) # Model forward
metrics = p_eval_step(state.params, batch) batch = next(eval_loader)
eval_metrics.append(metrics) metrics = p_eval_step(state.params, batch)
eval_metrics.append(metrics)
# normalize eval metrics # normalize eval metrics
eval_metrics = get_metrics(eval_metrics) eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics) try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
except OverflowError:
eval_metrics["perplexity"] = float("inf")
try: # Print metrics and update progress bar
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
except OverflowError: epochs.write(desc)
eval_metrics["perplexity"] = float("inf") epochs.desc = desc
# Print metrics and update progress bar # Save metrics
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})" if has_tensorboard and jax.process_index() == 0:
epochs.write(desc) cur_step = epoch * (len(train_dataset) // train_batch_size)
epochs.desc = desc write_eval_metric(summary_writer, eval_metrics, cur_step)
# Save metrics if cur_step % training_args.save_steps == 0 and cur_step > 0:
if has_tensorboard and jax.process_index() == 0: # save checkpoint after each epoch and push checkpoint to the hub
cur_step = epoch * (len(train_dataset) // train_batch_size) if jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, cur_step) params = jax.device_get(unreplicate(state.params))
model.save_pretrained(
# save checkpoint after each epoch and push checkpoint to the hub training_args.output_dir,
if jax.process_index() == 0: params=params,
params = jax.device_get(unreplicate(state.params)) push_to_hub=training_args.push_to_hub,
model.save_pretrained( commit_message=f"Saving weights and logs of step {cur_step}",
training_args.output_dir, )
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -621,43 +621,43 @@ if __name__ == "__main__":
train_metrics = [] train_metrics = []
# ======================== Evaluating ============================== if cur_step % training_args.eval_steps == 0 and cur_step > 0:
num_eval_samples = len(tokenized_datasets["validation"]) # ======================== Evaluating ==============================
eval_samples_idx = jnp.arange(num_eval_samples) num_eval_samples = len(tokenized_datasets["validation"])
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = [] eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16) model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward # Model forward
model_inputs = shard(model_inputs.data) model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs) metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics) eval_metrics.append(metrics)
# normalize eval metrics # normalize eval metrics
eval_metrics = get_metrics(eval_metrics) eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.sum, eval_metrics) eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer") eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar # Update progress bar
epochs.desc = ( epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
)
# Save metrics # Save metrics
if has_tensorboard and jax.process_index() == 0: if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step) write_eval_metric(summary_writer, eval_metrics, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub if cur_step % training_args.save_steps == 0 and cur_step > 0:
if jax.process_index() == 0: # save checkpoint after each epoch and push checkpoint to the hub
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) if jax.process_index() == 0:
model.save_pretrained( params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
training_args.output_dir, model.save_pretrained(
params=params, training_args.output_dir,
push_to_hub=training_args.push_to_hub, params=params,
commit_message=f"Saving weights and logs of epoch {epoch+1}", push_to_hub=training_args.push_to_hub,
) commit_message=f"Saving weights and logs of step {cur_step}",
)

View File

@@ -737,36 +737,41 @@ if __name__ == "__main__":
train_metrics = [] train_metrics = []
# ======================== Evaluating ============================== if cur_step % training_args.eval_steps == 0 and cur_step > 0:
num_eval_samples = len(tokenized_datasets["validation"]) # ======================== Evaluating ==============================
eval_samples_idx = jnp.arange(num_eval_samples) num_eval_samples = len(tokenized_datasets["validation"])
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_samples_idx = jnp.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = [] eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples) model_inputs = data_collator(samples)
# Model forward # Model forward
model_inputs = shard(model_inputs.data) model_inputs = shard(model_inputs.data)
metrics = p_eval_step(state.params, model_inputs) metrics = p_eval_step(state.params, model_inputs)
eval_metrics.append(metrics) eval_metrics.append(metrics)
# get eval metrics # get eval metrics
eval_metrics = get_metrics(eval_metrics) eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
# Update progress bar # Update progress bar
epochs.write( epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
)
# Save metrics # Save metrics
if has_tensorboard and jax.process_index() == 0: if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
write_eval_metric(summary_writer, eval_metrics, cur_step) write_eval_metric(summary_writer, eval_metrics, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub if cur_step % training_args.save_steps == 0 and cur_step > 0:
if jax.process_index() == 0: # save checkpoint after each epoch and push checkpoint to the hub
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) if jax.process_index() == 0:
model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub) params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)