From aecae53377261318e6f13025523f685d4ceda4e8 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 29 Jun 2021 14:02:33 +0530 Subject: [PATCH] [example/flax] add summarization readme (#12393) * add readme * update readme and add requirements * Update examples/flax/summarization/README.md Co-authored-by: Patrick von Platen --- examples/flax/summarization/README.md | 66 ++++++++++++++++++++ examples/flax/summarization/requirements.txt | 5 ++ 2 files changed, 71 insertions(+) create mode 100644 examples/flax/summarization/README.md create mode 100644 examples/flax/summarization/requirements.txt diff --git a/examples/flax/summarization/README.md b/examples/flax/summarization/README.md new file mode 100644 index 0000000000..adc9cb15e3 --- /dev/null +++ b/examples/flax/summarization/README.md @@ -0,0 +1,66 @@ +# Summarization (Seq2Seq model) training examples + +The following example showcases how to finetune a sequence-to-sequence model for summarization +using the JAX/Flax backend. + +JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. +Models written in JAX/Flax are **immutable** and updated in a purely functional +way which enables simple and efficient model parallelism. + +`run_summarization_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it. + +For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below. + +Let's start by creating a model repository to save the trained model and logs. +Here we call the model `"bart-base-xsum"`, but you can change the model name as you like. + +You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that +you are logged in) or via the command line: + +``` +huggingface-cli repo create bart-base-xsum +``` +Next we clone the model repository to add the tokenizer and model files. +``` +git clone https://huggingface.co//bart-base-xsum +``` +To ensure that all tensorboard traces will be uploaded correctly, we need to +track them. You can run the following command inside your model repo to do so. + +``` +cd bart-base-xsum +git lfs track "*tfevents*" +``` + +Great, we have set up our model repository. During training, we will automatically +push the training logs and model weights to the repo. + +Next, let's add a symbolic link to the `run_summarization_flax.py`. + +```bash +export MODEL_DIR="./bart-base-xsum" +ln -s ~/transformers/examples/flax/summarization/run_summarization_flax.py run_summarization_flax.py +``` + +### Train the model +Next we can run the example script to train the model: + +```bash +python run_summarization_flax.py \ + --output_dir ${MODEL_DIR} \ + --model_name_or_path facebook/bart-base \ + --tokenizer_name facebook/bart-base \ + --dataset_name="xsum" \ + --do_train --do_eval --do_predict --predict_with_generate \ + --num_train_epochs 6 \ + --learning_rate 5e-5 --warmup_steps 0 \ + --per_device_train_batch_size 64 \ + --per_device_eval_batch_size 64 \ + --overwrite_output_dir \ + --max_source_length 512 --max_target_length 64 \ + --push_to_hub +``` + +This should finish in 37min, with validation loss and ROUGE2 score of 1.7785 and 17.01 respectively after 6 epochs. training statistics can be accessed on [tfhub.de](https://tensorboard.dev/experiment/OcPfOIgXRMSJqYB4RdK2tA/#scalars). + +> Note that here we used default `generate` arguments, using arguments specific for `xsum` dataset should give better ROUGE scores. diff --git a/examples/flax/summarization/requirements.txt b/examples/flax/summarization/requirements.txt new file mode 100644 index 0000000000..6ab626a17f --- /dev/null +++ b/examples/flax/summarization/requirements.txt @@ -0,0 +1,5 @@ +datasets >= 1.1.3 +jax>=0.2.8 +jaxlib>=0.1.59 +flax>=0.3.4 +optax>=0.0.8