[Flax] Adapt flax examples to include push_to_hub (#12391)
* fix_torch_device_generate_test * remove @ * finish * correct summary writer * correct push to hub * fix indent * finish * finish * finish * finish * finish Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
a7d0b288fa
commit
2d70c91206
@@ -33,11 +33,37 @@ in Norwegian on a single TPUv3-8 pod.
|
|||||||
|
|
||||||
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
|
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
|
||||||
|
|
||||||
Let's start by creating a folder to save the trained model and a symbolic link to the `run_mlm_flax.py` script.
|
Let's start by creating a model repository to save the trained model and logs.
|
||||||
|
Here we call the model `"norwegian-roberta-base"`, but you can change the model name as you like.
|
||||||
|
|
||||||
|
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
||||||
|
you are logged in) or via the command line:
|
||||||
|
|
||||||
|
```
|
||||||
|
huggingface-cli repo create norwegian-roberta-base
|
||||||
|
```
|
||||||
|
|
||||||
|
Next we clone the model repository to add the tokenizer and model files.
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://huggingface.co/<your-username>/norwegian-roberta-base
|
||||||
|
```
|
||||||
|
|
||||||
|
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
||||||
|
track them. You can run the following command inside your model repo to do so.
|
||||||
|
|
||||||
|
```
|
||||||
|
cd norwegian-roberta-base
|
||||||
|
git lfs track "*tfevents*"
|
||||||
|
```
|
||||||
|
|
||||||
|
Great, we have set up our model repository. During training, we will automatically
|
||||||
|
push the training logs and model weights to the repo.
|
||||||
|
|
||||||
|
Next, let's add a symbolic link to the `run_mlm_flax.py`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export MODEL_DIR="./norwegian-roberta-base"
|
export MODEL_DIR="./norwegian-roberta-base"
|
||||||
mkdir -p ${MODEL_DIR}
|
|
||||||
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
|
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -98,7 +124,7 @@ Next we can run the example script to pretrain the model:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
./run_mlm_flax.py \
|
./run_mlm_flax.py \
|
||||||
--output_dir="./runs" \
|
--output_dir="${MODEL_DIR}" \
|
||||||
--model_type="roberta" \
|
--model_type="roberta" \
|
||||||
--config_name="${MODEL_DIR}" \
|
--config_name="${MODEL_DIR}" \
|
||||||
--tokenizer_name="${MODEL_DIR}" \
|
--tokenizer_name="${MODEL_DIR}" \
|
||||||
@@ -114,7 +140,8 @@ Next we can run the example script to pretrain the model:
|
|||||||
--pad_to_max_length \
|
--pad_to_max_length \
|
||||||
--num_train_epochs="18" \
|
--num_train_epochs="18" \
|
||||||
--adam_beta1="0.9" \
|
--adam_beta1="0.9" \
|
||||||
--adam_beta2="0.98"
|
--adam_beta2="0.98" \
|
||||||
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
Training should converge at a loss and accuracy
|
Training should converge at a loss and accuracy
|
||||||
@@ -135,11 +162,37 @@ in Norwegian on a single TPUv3-8 pod.
|
|||||||
|
|
||||||
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
|
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
|
||||||
|
|
||||||
Let's start by creating a folder to save the trained model and a symbolic link to the `run_clm_flax.py` script.
|
Let's start by creating a model repository to save the trained model and logs.
|
||||||
|
Here we call the model `"norwegian-gpt2"`, but you can change the model name as you like.
|
||||||
|
|
||||||
|
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
||||||
|
you are logged in) or via the command line:
|
||||||
|
|
||||||
|
```
|
||||||
|
huggingface-cli repo create norwegian-gpt2
|
||||||
|
```
|
||||||
|
|
||||||
|
Next we clone the model repository to add the tokenizer and model files.
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://huggingface.co/<your-username>/norwegian-gpt2
|
||||||
|
```
|
||||||
|
|
||||||
|
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
||||||
|
track them. You can run the following command inside your model repo to do so.
|
||||||
|
|
||||||
|
```
|
||||||
|
cd norwegian-gpt2
|
||||||
|
git lfs track "*tfevents*"
|
||||||
|
```
|
||||||
|
|
||||||
|
Great, we have set up our model repository. During training, we will automatically
|
||||||
|
push the training logs and model weights to the repo.
|
||||||
|
|
||||||
|
Next, let's add a symbolic link to the `run_clm_flax.py`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export MODEL_DIR="./norwegian-gpt2"
|
export MODEL_DIR="./norwegian-gpt2"
|
||||||
mkdir -p ${MODEL_DIR}
|
|
||||||
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
|
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -166,7 +219,7 @@ Next we can run the example script to pretrain the model:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
./run_clm_flax.py \
|
./run_clm_flax.py \
|
||||||
--output_dir="./runs" \
|
--output_dir="${MODEL_DIR}" \
|
||||||
--model_type="gpt2" \
|
--model_type="gpt2" \
|
||||||
--config_name="${MODEL_DIR}" \
|
--config_name="${MODEL_DIR}" \
|
||||||
--tokenizer_name="${MODEL_DIR}" \
|
--tokenizer_name="${MODEL_DIR}" \
|
||||||
@@ -180,6 +233,7 @@ Next we can run the example script to pretrain the model:
|
|||||||
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
|
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--num_train_epochs="20" \
|
--num_train_epochs="20" \
|
||||||
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
Training should converge at a loss and perplexity
|
Training should converge at a loss and perplexity
|
||||||
@@ -197,14 +251,9 @@ For reproducibility, we state the training commands used for PyTorch/XLA and PyT
|
|||||||
| Task | [TPU v3-8 (Flax)](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg/) | [TPU v3-8 (Pytorch/XLA)](https://tensorboard.dev/experiment/7Jq1kcQQRAmy12KOdXek7A/)| [8 GPU (PyTorch)](https://tensorboard.dev/experiment/PJneV8FQRxa2unPw1QnVHA) |
|
| Task | [TPU v3-8 (Flax)](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg/) | [TPU v3-8 (Pytorch/XLA)](https://tensorboard.dev/experiment/7Jq1kcQQRAmy12KOdXek7A/)| [8 GPU (PyTorch)](https://tensorboard.dev/experiment/PJneV8FQRxa2unPw1QnVHA) |
|
||||||
|-------|-----------|------------|------------|
|
|-------|-----------|------------|------------|
|
||||||
| MLM | 15h32m | 23h46m | 44h14m |
|
| MLM | 15h32m | 23h46m | 44h14m |
|
||||||
| **COST*** | $124.24 | $187.84 | $877.92 |
|
|
||||||
|
|
||||||
*All experiments are ran on Google Cloud Platform. Prices are on-demand prices
|
*All experiments are ran on Google Cloud Platform.
|
||||||
(not preemptible), obtained on May 12, 2021 for zone Iowa (us-central1) using
|
GPU experiments are ran without further optimizations besides JAX
|
||||||
the following tables:
|
|
||||||
[TPU pricing table](https://cloud.google.com/tpu/pricing) ($8.00/h for v3-8),
|
|
||||||
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
|
|
||||||
V100 GPU). GPU experiments are ran without further optimizations besides JAX
|
|
||||||
transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8"
|
transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8"
|
||||||
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.
|
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.
|
||||||
|
|
||||||
@@ -281,7 +330,7 @@ mkdir -p ${MODEL_DIR}
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
|
python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
|
||||||
--output_dir="./runs" \
|
--output_dir="${MODEL_DIR}" \
|
||||||
--model_type="roberta" \
|
--model_type="roberta" \
|
||||||
--config_name="${MODEL_DIR}" \
|
--config_name="${MODEL_DIR}" \
|
||||||
--tokenizer_name="${MODEL_DIR}" \
|
--tokenizer_name="${MODEL_DIR}" \
|
||||||
|
|||||||
11
examples/flax/language-modeling/run_clm_flax.py
Normal file → Executable file
11
examples/flax/language-modeling/run_clm_flax.py
Normal file → Executable file
@@ -451,7 +451,7 @@ def main():
|
|||||||
|
|
||||||
# Enable tensorboard only on the master node
|
# Enable tensorboard only on the master node
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
||||||
|
|
||||||
# Initialize our training
|
# Initialize our training
|
||||||
rng = jax.random.PRNGKey(training_args.seed)
|
rng = jax.random.PRNGKey(training_args.seed)
|
||||||
@@ -604,10 +604,15 @@ def main():
|
|||||||
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
||||||
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
||||||
|
|
||||||
# save last checkpoint
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(unreplicate(state.params))
|
params = jax.device_get(unreplicate(state.params))
|
||||||
model.save_pretrained(training_args.output_dir, params=params)
|
model.save_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
params=params,
|
||||||
|
push_to_hub=training_args.push_to_hub,
|
||||||
|
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -269,7 +269,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
|
|||||||
return batch_idx
|
return batch_idx
|
||||||
|
|
||||||
|
|
||||||
def write_metric(train_metrics, eval_metrics, train_time, step):
|
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
||||||
summary_writer.scalar("train_time", train_time, step)
|
summary_writer.scalar("train_time", train_time, step)
|
||||||
|
|
||||||
train_metrics = get_metrics(train_metrics)
|
train_metrics = get_metrics(train_metrics)
|
||||||
@@ -472,7 +472,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Enable tensorboard only on the master node
|
# Enable tensorboard only on the master node
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
# This one will take care of randomly masking the tokens.
|
# This one will take care of randomly masking the tokens.
|
||||||
@@ -642,9 +642,14 @@ if __name__ == "__main__":
|
|||||||
# Save metrics
|
# Save metrics
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
||||||
write_metric(train_metrics, eval_metrics, train_time, cur_step)
|
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
||||||
|
|
||||||
# save last checkpoint
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(training_args.output_dir, params=params)
|
model.save_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
params=params,
|
||||||
|
push_to_hub=training_args.push_to_hub,
|
||||||
|
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
||||||
|
)
|
||||||
|
|||||||
@@ -542,7 +542,7 @@ def main():
|
|||||||
try:
|
try:
|
||||||
from flax.metrics.tensorboard import SummaryWriter
|
from flax.metrics.tensorboard import SummaryWriter
|
||||||
|
|
||||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
||||||
except ImportError as ie:
|
except ImportError as ie:
|
||||||
has_tensorboard = False
|
has_tensorboard = False
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -787,10 +787,15 @@ def main():
|
|||||||
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
||||||
logger.info(desc)
|
logger.info(desc)
|
||||||
|
|
||||||
# save last checkpoint
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(unreplicate(state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(training_args.output_dir, params=params)
|
model.save_pretrained(
|
||||||
|
training_args.output_dir,
|
||||||
|
params=params,
|
||||||
|
push_to_hub=training_args.push_to_hub,
|
||||||
|
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -23,31 +23,68 @@ Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transfor
|
|||||||
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
||||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).
|
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).
|
||||||
|
|
||||||
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
To begin with it is recommended to create a model repository to save the trained model and logs.
|
||||||
|
Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like.
|
||||||
|
|
||||||
|
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
|
||||||
|
you are logged in) or via the command line:
|
||||||
|
|
||||||
|
```
|
||||||
|
huggingface-cli repo create bert-glue-mrpc-test
|
||||||
|
```
|
||||||
|
|
||||||
|
Next we clone the model repository to add the tokenizer and model files.
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://huggingface.co/<your-username>/bert-glue-mrpc-test
|
||||||
|
```
|
||||||
|
|
||||||
|
To ensure that all tensorboard traces will be uploaded correctly, we need to
|
||||||
|
track them. You can run the following command inside your model repo to do so.
|
||||||
|
|
||||||
|
```
|
||||||
|
cd bert-glue-mrpc-test
|
||||||
|
git lfs track "*tfevents*"
|
||||||
|
```
|
||||||
|
|
||||||
|
Great, we have set up our model repository. During training, we will automatically
|
||||||
|
push the training logs and model weights to the repo.
|
||||||
|
|
||||||
|
Next, let's add a symbolic link to the `run_flax_glue.py`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export TASK_NAME=mrpc
|
export TASK_NAME=mrpc
|
||||||
|
export MODEL_DIR="./bert-glue-mrpc-test"
|
||||||
|
ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
||||||
|
|
||||||
|
```bash
|
||||||
python run_flax_glue.py \
|
python run_flax_glue.py \
|
||||||
--model_name_or_path bert-base-cased \
|
--model_name_or_path bert-base-cased \
|
||||||
--task_name $TASK_NAME \
|
--task_name ${TASK_NAME} \
|
||||||
--max_length 128 \
|
--max_length 128 \
|
||||||
--learning_rate 2e-5 \
|
--learning_rate 2e-5 \
|
||||||
--num_train_epochs 3 \
|
--num_train_epochs 3 \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
--output_dir /tmp/$TASK_NAME/
|
--output_dir ${MODEL_DIR} \
|
||||||
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
|
where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
|
||||||
|
|
||||||
Using the command above, the script will train for 3 epochs and run eval after each epoch.
|
Using the command above, the script will train for 3 epochs and run eval after each epoch.
|
||||||
Metrics and hyperparameters are stored in Tensorflow event files in `---output_dir`.
|
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
|
||||||
You can see the results by running `tensorboard` in that directory:
|
You can see the results by running `tensorboard` in that directory:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ tensorboard --logdir .
|
$ tensorboard --logdir .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
or directly on the hub under *Training metrics*.
|
||||||
|
|
||||||
### Accuracy Evaluation
|
### Accuracy Evaluation
|
||||||
|
|
||||||
We train five replicas and report mean accuracy and stdev on the dev set below.
|
We train five replicas and report mean accuracy and stdev on the dev set below.
|
||||||
@@ -95,14 +132,8 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https
|
|||||||
| WNLI | 1m 11s | 48s | 39s | 36s |
|
| WNLI | 1m 11s | 48s | 39s | 36s |
|
||||||
|-------|
|
|-------|
|
||||||
| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
|
| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
|
||||||
| **COST*** | $8.56 | $29.10 | $13.06 | $16.41 |
|
|
||||||
|
|
||||||
|
*All experiments are ran on Google Cloud Platform.
|
||||||
*All experiments are ran on Google Cloud Platform. Prices are on-demand prices
|
GPU experiments are ran without further optimizations besides JAX
|
||||||
(not preemptible), obtained on May 12, 2021 for zone Iowa (us-central1) using
|
|
||||||
the following tables:
|
|
||||||
[TPU pricing table](https://cloud.google.com/tpu/pricing) ($8.00/h for v3-8),
|
|
||||||
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
|
|
||||||
V100 GPU). GPU experiments are ran without further optimizations besides JAX
|
|
||||||
transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8"
|
transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8"
|
||||||
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.
|
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.
|
||||||
|
|||||||
@@ -123,6 +123,11 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||||
parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
|
parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--push_to_hub",
|
||||||
|
action="store_true",
|
||||||
|
help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Sanity checks
|
# Sanity checks
|
||||||
@@ -491,10 +496,15 @@ def main():
|
|||||||
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
||||||
write_metric(train_metrics, eval_metric, train_time, cur_step)
|
write_metric(train_metrics, eval_metric, train_time, cur_step)
|
||||||
|
|
||||||
# save last checkpoint
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||||
model.save_pretrained(args.output_dir, params=params)
|
model.save_pretrained(
|
||||||
|
args.output_dir,
|
||||||
|
params=params,
|
||||||
|
push_to_hub=args.push_to_hub,
|
||||||
|
commit_message=f"Saving weights and logs of epoch {epoch}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user