[seq2seq] make it easier to run the scripts (#7274)
This commit is contained in:
@@ -100,7 +100,7 @@ All finetuning bash scripts call finetune.py (or distillation.py) with reasonabl
|
|||||||
To see all the possible command line options, run:
|
To see all the possible command line options, run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./finetune.sh --help # this calls python finetune.py --help
|
./finetune.py --help
|
||||||
```
|
```
|
||||||
|
|
||||||
### Finetuning Training Params
|
### Finetuning Training Params
|
||||||
@@ -197,7 +197,7 @@ If 'translation' is in your task name, the computed metric will be BLEU. Otherwi
|
|||||||
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
|
For t5, you need to specify --task translation_{src}_to_{tgt} as follows:
|
||||||
```bash
|
```bash
|
||||||
export DATA_DIR=wmt_en_ro
|
export DATA_DIR=wmt_en_ro
|
||||||
python run_eval.py t5-base \
|
./run_eval.py t5-base \
|
||||||
$DATA_DIR/val.source t5_val_generations.txt \
|
$DATA_DIR/val.source t5_val_generations.txt \
|
||||||
--reference_path $DATA_DIR/val.target \
|
--reference_path $DATA_DIR/val.target \
|
||||||
--score_path enro_bleu.json \
|
--score_path enro_bleu.json \
|
||||||
@@ -211,7 +211,7 @@ python run_eval.py t5-base \
|
|||||||
This command works for MBART, although the BLEU score is suspiciously low.
|
This command works for MBART, although the BLEU score is suspiciously low.
|
||||||
```bash
|
```bash
|
||||||
export DATA_DIR=wmt_en_ro
|
export DATA_DIR=wmt_en_ro
|
||||||
python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
|
./run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
|
||||||
--reference_path $DATA_DIR/val.target \
|
--reference_path $DATA_DIR/val.target \
|
||||||
--score_path enro_bleu.json \
|
--score_path enro_bleu.json \
|
||||||
--task translation \
|
--task translation \
|
||||||
@@ -224,7 +224,7 @@ python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_gen
|
|||||||
Summarization (xsum will be very similar):
|
Summarization (xsum will be very similar):
|
||||||
```bash
|
```bash
|
||||||
export DATA_DIR=cnn_dm
|
export DATA_DIR=cnn_dm
|
||||||
python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
|
./run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
|
||||||
--reference_path $DATA_DIR/val.target \
|
--reference_path $DATA_DIR/val.target \
|
||||||
--score_path cnn_rouge.json \
|
--score_path cnn_rouge.json \
|
||||||
--task summarization \
|
--task summarization \
|
||||||
@@ -238,7 +238,7 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
|
|||||||
### Multi-GPU Evalulation
|
### Multi-GPU Evalulation
|
||||||
here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases
|
here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases
|
||||||
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
|
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
|
||||||
`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all clargs.
|
`{type_path}.source` and `{type_path}.target`. Run `./run_distributed_eval.py --help` for all clargs.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
|
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
|
||||||
@@ -371,11 +371,11 @@ This feature can only be used:
|
|||||||
- with fairseq installed
|
- with fairseq installed
|
||||||
- on 1 GPU
|
- on 1 GPU
|
||||||
- without sortish sampler
|
- without sortish sampler
|
||||||
- after calling `python save_len_file.py $tok $data_dir`
|
- after calling `./save_len_file.py $tok $data_dir`
|
||||||
|
|
||||||
For example,
|
For example,
|
||||||
```bash
|
```bash
|
||||||
python save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
|
./save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
|
||||||
./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
|
./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
|
||||||
```
|
```
|
||||||
splits `wmt_en_ro/train` into 11,197 uneven lengthed batches and can finish 1 epoch in 8 minutes on a v100.
|
splits `wmt_en_ro/train` into 11,197 uneven lengthed batches and can finish 1 epoch in 8 minutes on a v100.
|
||||||
|
|||||||
2
examples/seq2seq/convert_model_to_fp16.py
Normal file → Executable file
2
examples/seq2seq/convert_model_to_fp16.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
|||||||
2
examples/seq2seq/convert_pl_checkpoint_to_hf.py
Normal file → Executable file
2
examples/seq2seq/convert_pl_checkpoint_to_hf.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|||||||
9
examples/seq2seq/distillation.py
Normal file → Executable file
9
examples/seq2seq/distillation.py
Normal file → Executable file
@@ -1,6 +1,9 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -13,7 +16,6 @@ from torch.nn import functional as F
|
|||||||
from finetune import SummarizationModule, TranslationModule
|
from finetune import SummarizationModule, TranslationModule
|
||||||
from finetune import main as ft_main
|
from finetune import main as ft_main
|
||||||
from initialization_utils import copy_layers, init_student
|
from initialization_utils import copy_layers, init_student
|
||||||
from lightning_base import generic_train
|
|
||||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from utils import (
|
from utils import (
|
||||||
@@ -27,6 +29,11 @@ from utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# need the parent dir module
|
||||||
|
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||||
|
from lightning_base import generic_train # noqa
|
||||||
|
|
||||||
|
|
||||||
class BartSummarizationDistiller(SummarizationModule):
|
class BartSummarizationDistiller(SummarizationModule):
|
||||||
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
||||||
|
|
||||||
|
|||||||
2
examples/seq2seq/download_wmt.py
Normal file → Executable file
2
examples/seq2seq/download_wmt.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
|||||||
9
examples/seq2seq/finetune.py
Normal file → Executable file
9
examples/seq2seq/finetune.py
Normal file → Executable file
@@ -1,7 +1,10 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -13,7 +16,6 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
|
||||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from utils import (
|
from utils import (
|
||||||
@@ -34,6 +36,11 @@ from utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# need the parent dir module
|
||||||
|
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||||
|
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
# Add parent directory to python path to access lightning_base.py
|
|
||||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
|
||||||
|
|
||||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||||
# run ./finetune.sh --help to see all the possible options
|
# run ./finetune.sh --help to see all the possible options
|
||||||
python finetune.py \
|
python finetune.py \
|
||||||
|
|||||||
2
examples/seq2seq/minify_dataset.py
Normal file → Executable file
2
examples/seq2seq/minify_dataset.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
|||||||
2
examples/seq2seq/pack_dataset.py
Normal file → Executable file
2
examples/seq2seq/pack_dataset.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
"""Fill examples with bitext up to max_tokens without breaking up examples.
|
"""Fill examples with bitext up to max_tokens without breaking up examples.
|
||||||
[['I went', 'yo fui'],
|
[['I went', 'yo fui'],
|
||||||
['to the store', 'a la tienda']
|
['to the store', 'a la tienda']
|
||||||
|
|||||||
2
examples/seq2seq/run_distributed_eval.py
Normal file → Executable file
2
examples/seq2seq/run_distributed_eval.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
|||||||
2
examples/seq2seq/run_eval.py
Normal file → Executable file
2
examples/seq2seq/run_eval.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|||||||
2
examples/seq2seq/run_eval_search.py
Normal file → Executable file
2
examples/seq2seq/run_eval_search.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import itertools
|
import itertools
|
||||||
import operator
|
import operator
|
||||||
|
|||||||
9
examples/seq2seq/save_len_file.py
Normal file → Executable file
9
examples/seq2seq/save_len_file.py
Normal file → Executable file
@@ -1,14 +1,11 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
from utils import Seq2SeqDataset, pickle_save
|
||||||
|
|
||||||
try:
|
|
||||||
from .utils import Seq2SeqDataset, pickle_save
|
|
||||||
except ImportError:
|
|
||||||
from utils import Seq2SeqDataset, pickle_save
|
|
||||||
|
|
||||||
|
|
||||||
def save_len_file(
|
def save_len_file(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -6,14 +6,13 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from pack_dataset import pack_data_dir
|
||||||
|
from save_len_file import save_len_file
|
||||||
|
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
|
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
|
||||||
from .pack_dataset import pack_data_dir
|
|
||||||
from .save_len_file import save_len_file
|
|
||||||
from .test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
|
|
||||||
from .utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
|
|
||||||
|
|
||||||
|
|
||||||
BERT_BASE_CASED = "bert-base-cased"
|
BERT_BASE_CASED = "bert-base-cased"
|
||||||
|
|||||||
@@ -14,19 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import unittest
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .utils import calculate_bleu
|
|
||||||
except ImportError:
|
|
||||||
from utils import calculate_bleu
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
|
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
|
||||||
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
|
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
|
||||||
|
from utils import calculate_bleu
|
||||||
|
|
||||||
|
|
||||||
filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"
|
filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"
|
||||||
|
|||||||
Reference in New Issue
Block a user