split seq2seq script into summarization & translation (#10611)
* split seq2seq script, update docs
* needless diff
* fix readme
* remove test diff
* s/summarization/translation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* cr
* fix arguments & better mbart/t5 refs
* copyright

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* reword readme

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* s/summarization/translation
* short script names
* fix tests
* fix isort, include mbart doc
* delete old script, update tests
* automate source prefix
* automate source prefix for translation
* s/translation/trans

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* fix script name (short version)
* typos

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* exact parameter

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* remove superfluous source_prefix calls in docs
* rename scripts & warn for source prefix
* black
* flake8

Co-authored-by: theo <theo@matussie.re>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
@@ -30,7 +30,7 @@ For the old `finetune_trainer.py` and related utils, see [`examples/legacy/seq2s
 - `FSMTForConditionalGeneration` (translation only)
 - `T5ForConditionalGeneration`
 
-`run_seq2seq.py` is a lightweight example of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
+`run_summarization.py` and `run_translation.py` are lightweight examples of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
 
 For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files
 and you also will find examples of these below.
@@ -39,11 +39,10 @@ and you also will find examples of these below.
 
 Here is an example on a summarization task:
 ```bash
-python examples/seq2seq/run_seq2seq.py \
+python examples/seq2seq/run_summarization.py \
     --model_name_or_path t5-small \
     --do_train \
     --do_eval \
-    --task summarization \
     --dataset_name xsum \
     --output_dir /tmp/tst-summarization \
     --per_device_train_batch_size=4 \
@@ -60,11 +59,10 @@ And here is how you would use it on your own files, after adjusting the values f
 `--train_file`, `--validation_file`, `--text_column` and `--summary_column` to match your setup:
 
 ```bash
-python examples/seq2seq/run_seq2seq.py \
+python examples/seq2seq/run_summarization.py \
     --model_name_or_path t5-small \
     --do_train \
     --do_eval \
-    --task summarization \
     --train_file path_to_csv_or_jsonlines_file \
     --validation_file path_to_csv_or_jsonlines_file \
     --output_dir /tmp/tst-summarization \
@@ -140,14 +138,14 @@ And as with the CSV files, you can specify which values to select from the file,
 Here is an example of a translation fine-tuning with T5:
 
 ```bash
-python examples/seq2seq/run_seq2seq.py \
+python examples/seq2seq/run_translation.py \
     --model_name_or_path t5-small \
     --do_train \
     --do_eval \
-    --task translation_en_to_ro \
+    --source_lang en \
+    --target_lang ro \
     --dataset_name wmt16 \
     --dataset_config_name ro-en \
-    --source_prefix "translate English to Romanian: " \
     --output_dir /tmp/tst-translation \
     --per_device_train_batch_size=4 \
     --per_device_eval_batch_size=4 \
@@ -160,11 +158,10 @@ python examples/seq2seq/run_seq2seq.py \
 And the same with MBart:
 
 ```bash
-python examples/seq2seq/run_seq2seq.py \
+python examples/seq2seq/run_translation.py \
     --model_name_or_path facebook/mbart-large-en-ro \
     --do_train \
     --do_eval \
-    --task translation_en_to_ro \
     --dataset_name wmt16 \
     --dataset_config_name ro-en \
     --source_lang en_XX \
@@ -180,18 +177,8 @@ python examples/seq2seq/run_seq2seq.py \
 
 Note, that depending on the used model additional language-specific command-line arguments are sometimes required. Specifically:
 
-* MBart models require:
-```
---source_lang en_XX \
---target_lang ro_RO \
-```
-* T5 requires:
-
-```
---source_prefix "translate English to Romanian: "
-```
-
-* yet, other models, require neither.
+* MBart models require different `--{source,target}_lang` values, e.g. in place of `en` it expects `en_XX`, for `ro` it expects `ro_RO`. The full MBart specification for language codes can be looked up [here](https://huggingface.co/facebook/mbart-large-cc25)
+* T5 models can use a `--source_prefix` argument to override the otherwise automated prefix of the form `translate {source_lang} to {target_lang}` for `run_translation.py` and `summarize: ` for `run_summarization.py`
 
 Also, if you switch to a different language pair, make sure to adjust the source and target values in all command line arguments.
 
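The prefix rule stated in the T5 bullet of the hunk above can be illustrated with a small Python sketch. This is not the code in `run_translation.py` or `run_summarization.py`: the function name `resolve_source_prefix`, the `LANGUAGE_NAMES` table, and the trailing `": "` on the translation prefix (taken from the `--source_prefix` values in the removed command lines) are all assumptions made for illustration.

```python
from typing import Optional

# Hypothetical code-to-name table; it covers only the languages that appear
# in this README's examples and is not part of the example scripts.
LANGUAGE_NAMES = {"en": "English", "ro": "Romanian", "de": "German"}


def resolve_source_prefix(
    task: str,
    source_lang: Optional[str] = None,
    target_lang: Optional[str] = None,
    source_prefix: Optional[str] = None,
) -> str:
    """Return the text prepended to every input before tokenization."""
    # An explicit --source_prefix always overrides the automatic default.
    if source_prefix is not None:
        return source_prefix
    if task == "summarization":
        return "summarize: "
    if task == "translation":
        if source_lang is None or target_lang is None:
            raise ValueError("translation needs source_lang and target_lang")
        src = LANGUAGE_NAMES[source_lang]
        tgt = LANGUAGE_NAMES[target_lang]
        # Assumed formatting: language names plus a trailing ": ".
        return f"translate {src} to {tgt}: "
    raise ValueError(f"unknown task: {task!r}")


if __name__ == "__main__":
    print(repr(resolve_source_prefix("translation", "en", "ro")))
    print(repr(resolve_source_prefix("summarization")))
```

Under these assumptions, passing `source_prefix` explicitly bypasses the lookup entirely, which mirrors the override behaviour the bullet describes.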
@@ -199,14 +186,14 @@ And here is how you would use the translation finetuning on your own files, afte
 values for the arguments `--train_file`, `--validation_file` to match your setup:
 
 ```bash
-python examples/seq2seq/run_seq2seq.py \
+python examples/seq2seq/run_translation.py \
     --model_name_or_path t5-small \
     --do_train \
     --do_eval \
-    --task translation_en_to_ro \
+    --source_lang en \
+    --target_lang ro \
     --dataset_name wmt16 \
     --dataset_config_name ro-en \
-    --source_prefix "translate English to Romanian: " \
     --train_file path_to_jsonlines_file \
     --validation_file path_to_jsonlines_file \
     --output_dir /tmp/tst-translation \
@@ -229,13 +216,13 @@ Here the languages are Romanian (`ro`) and English (`en`).
 If you want to use a pre-processed dataset that leads to high bleu scores, but for the `en-de` language pair, you can use `--dataset_name wmt14-en-de-pre-processed`, as following:
 
 ```bash
-python examples/seq2seq/run_seq2seq.py \
+python examples/seq2seq/run_translation.py \
     --model_name_or_path t5-small \
     --do_train \
     --do_eval \
-    --task translation_en_to_de \
+    --source_lang en \
+    --target_lang de \
     --dataset_name wmt14-en-de-pre-processed \
-    --source_prefix "translate English to German: " \
     --output_dir /tmp/tst-translation \
     --per_device_train_batch_size=4 \
     --per_device_eval_batch_size=4 \