[Flax] Adapt examples to be able to use eval_steps and save_steps (#12543)
* fix_torch_device_generate_test * remove @ * up * up * correct * upload Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
2870fd198f
commit
208df208bf
@@ -141,6 +141,8 @@ Next we can run the example script to pretrain the model:
|
|||||||
--adam_beta1="0.9" \
|
--adam_beta1="0.9" \
|
||||||
--adam_beta2="0.98" \
|
--adam_beta2="0.98" \
|
||||||
--logging_steps="500" \
|
--logging_steps="500" \
|
||||||
|
--save_steps="2500" \
|
||||||
|
--eval_steps="2500" \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -234,6 +236,8 @@ Next we can run the example script to pretrain the model:
|
|||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--num_train_epochs="20" \
|
--num_train_epochs="20" \
|
||||||
--logging_steps="500" \
|
--logging_steps="500" \
|
||||||
|
--save_steps="2500" \
|
||||||
|
--eval_steps="2500" \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -370,6 +374,8 @@ Next we can run the example script to pretrain the model:
|
|||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--num_train_epochs="10" \
|
--num_train_epochs="10" \
|
||||||
--logging_steps="500" \
|
--logging_steps="500" \
|
||||||
|
--save_steps="2500" \
|
||||||
|
--eval_steps="2500" \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -587,45 +587,46 @@ def main():
|
|||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
||||||
eval_metrics = []
|
# ======================== Evaluating ==============================
|
||||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
eval_metrics = []
|
||||||
eval_steps = len(eval_dataset) // eval_batch_size
|
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
||||||
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
eval_steps = len(eval_dataset) // eval_batch_size
|
||||||
# Model forward
|
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
||||||
batch = next(eval_loader)
|
# Model forward
|
||||||
metrics = p_eval_step(state.params, batch)
|
batch = next(eval_loader)
|
||||||
eval_metrics.append(metrics)
|
metrics = p_eval_step(state.params, batch)
|
||||||
|
eval_metrics.append(metrics)
|
||||||
|
|
||||||
# normalize eval metrics
|
# normalize eval metrics
|
||||||
eval_metrics = get_metrics(eval_metrics)
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
|
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||||
|
|
||||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
try:
|
||||||
|
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
||||||
|
except OverflowError:
|
||||||
|
eval_metrics["perplexity"] = float("inf")
|
||||||
|
|
||||||
try:
|
# Print metrics and update progress bar
|
||||||
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
|
||||||
except OverflowError:
|
epochs.write(desc)
|
||||||
eval_metrics["perplexity"] = float("inf")
|
epochs.desc = desc
|
||||||
|
|
||||||
# Print metrics and update progress bar
|
# Save metrics
|
||||||
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
epochs.write(desc)
|
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
||||||
epochs.desc = desc
|
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
||||||
|
|
||||||
# Save metrics
|
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
if jax.process_index() == 0:
|
||||||
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
params = jax.device_get(unreplicate(state.params))
|
||||||
|
model.save_pretrained(
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
training_args.output_dir,
|
||||||
if jax.process_index() == 0:
|
params=params,
|
||||||
params = jax.device_get(unreplicate(state.params))
|
push_to_hub=training_args.push_to_hub,
|
||||||
model.save_pretrained(
|
commit_message=f"Saving weights and logs of step {cur_step}",
|
||||||
training_args.output_dir,
|
)
|
||||||
params=params,
|
|
||||||
push_to_hub=training_args.push_to_hub,
|
|
||||||
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -621,43 +621,43 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
# ======================== Evaluating ==============================
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||||
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||||
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||||
|
|
||||||
# Model forward
|
# Model forward
|
||||||
model_inputs = shard(model_inputs.data)
|
model_inputs = shard(model_inputs.data)
|
||||||
metrics = p_eval_step(state.params, model_inputs)
|
metrics = p_eval_step(state.params, model_inputs)
|
||||||
eval_metrics.append(metrics)
|
eval_metrics.append(metrics)
|
||||||
|
|
||||||
# normalize eval metrics
|
# normalize eval metrics
|
||||||
eval_metrics = get_metrics(eval_metrics)
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
||||||
eval_normalizer = eval_metrics.pop("normalizer")
|
eval_normalizer = eval_metrics.pop("normalizer")
|
||||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||||
|
|
||||||
# Update progress bar
|
# Update progress bar
|
||||||
epochs.desc = (
|
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save metrics
|
# Save metrics
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
||||||
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
||||||
|
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
||||||
if jax.process_index() == 0:
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
if jax.process_index() == 0:
|
||||||
model.save_pretrained(
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
training_args.output_dir,
|
model.save_pretrained(
|
||||||
params=params,
|
training_args.output_dir,
|
||||||
push_to_hub=training_args.push_to_hub,
|
params=params,
|
||||||
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
push_to_hub=training_args.push_to_hub,
|
||||||
)
|
commit_message=f"Saving weights and logs of step {cur_step}",
|
||||||
|
)
|
||||||
|
|||||||
@@ -737,36 +737,41 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
# ======================== Evaluating ==============================
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||||
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||||
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||||
model_inputs = data_collator(samples)
|
model_inputs = data_collator(samples)
|
||||||
|
|
||||||
# Model forward
|
# Model forward
|
||||||
model_inputs = shard(model_inputs.data)
|
model_inputs = shard(model_inputs.data)
|
||||||
metrics = p_eval_step(state.params, model_inputs)
|
metrics = p_eval_step(state.params, model_inputs)
|
||||||
eval_metrics.append(metrics)
|
eval_metrics.append(metrics)
|
||||||
|
|
||||||
# get eval metrics
|
# get eval metrics
|
||||||
eval_metrics = get_metrics(eval_metrics)
|
eval_metrics = get_metrics(eval_metrics)
|
||||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||||
|
|
||||||
# Update progress bar
|
# Update progress bar
|
||||||
epochs.write(
|
epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save metrics
|
# Save metrics
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
||||||
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
||||||
|
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
||||||
if jax.process_index() == 0:
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
if jax.process_index() == 0:
|
||||||
model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
|
model.save_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
params=params,
|
||||||
|
push_to_hub=training_args.push_to_hub,
|
||||||
|
commit_message=f"Saving weights and logs of step {cur_step}",
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user