[examples] summarization/bart/finetune.py supports t5 (#3824)
renames `run_bart_sum.py` to `finetune.py`
This commit is contained in:
@@ -19,7 +19,7 @@ except ImportError:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BartSystem(BaseTransformer):
|
class SummarizationTrainer(BaseTransformer):
|
||||||
|
|
||||||
mode = "language-modeling"
|
mode = "language-modeling"
|
||||||
|
|
||||||
@@ -64,18 +64,18 @@ class BartSystem(BaseTransformer):
|
|||||||
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
|
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
# NOTE: this generation will not use the cache.
|
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||||
# NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
|
# NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
source_ids,
|
input_ids=source_ids,
|
||||||
source_mask,
|
attention_mask=source_mask,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=80,
|
max_length=80,
|
||||||
repetition_penalty=2.5,
|
repetition_penalty=2.5,
|
||||||
length_penalty=1.0,
|
length_penalty=1.0,
|
||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
preds = [
|
preds = [
|
||||||
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||||
@@ -161,20 +161,20 @@ def main(args):
|
|||||||
if not args.output_dir:
|
if not args.output_dir:
|
||||||
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
|
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
model = BartSystem(args)
|
model = SummarizationTrainer(args)
|
||||||
trainer = generic_train(model, args)
|
trainer = generic_train(model, args)
|
||||||
|
|
||||||
# Optionally, predict on dev set and write to output_dir
|
# Optionally, predict on dev set and write to output_dir
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
||||||
BartSystem.load_from_checkpoint(checkpoints[-1])
|
SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
|
||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
add_generic_args(parser, os.getcwd())
|
add_generic_args(parser, os.getcwd())
|
||||||
parser = BartSystem.add_model_specific_args(parser, os.getcwd())
|
parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
@@ -8,7 +8,7 @@ mkdir -p $OUTPUT_DIR
|
|||||||
# Add parent directory to python path to access transformer_base.py
|
# Add parent directory to python path to access transformer_base.py
|
||||||
export PYTHONPATH="../../":"${PYTHONPATH}"
|
export PYTHONPATH="../../":"${PYTHONPATH}"
|
||||||
|
|
||||||
python run_bart_sum.py \
|
python finetune.py \
|
||||||
--data_dir=./cnn-dailymail/cnn_dm \
|
--data_dir=./cnn-dailymail/cnn_dm \
|
||||||
--model_type=bart \
|
--model_type=bart \
|
||||||
--model_name_or_path=bart-large \
|
--model_name_or_path=bart-large \
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ mkdir -p $OUTPUT_DIR
|
|||||||
|
|
||||||
# Add parent directory to python path to access transformer_base.py and utils.py
|
# Add parent directory to python path to access transformer_base.py and utils.py
|
||||||
export PYTHONPATH="../../":"${PYTHONPATH}"
|
export PYTHONPATH="../../":"${PYTHONPATH}"
|
||||||
python run_bart_sum.py \
|
python finetune.py \
|
||||||
--data_dir=cnn_tiny/ \
|
--data_dir=cnn_tiny/ \
|
||||||
--model_type=bart \
|
--model_type=bart \
|
||||||
--model_name_or_path=sshleifer/bart-tiny-random \
|
--model_name_or_path=sshleifer/bart-tiny-random \
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
|
|||||||
from transformers import BartTokenizer
|
from transformers import BartTokenizer
|
||||||
|
|
||||||
from .evaluate_cnn import run_generate
|
from .evaluate_cnn import run_generate
|
||||||
from .run_bart_sum import main
|
from .finetune import main
|
||||||
from .utils import SummarizationDataset
|
from .utils import SummarizationDataset
|
||||||
|
|
||||||
|
|
||||||
@@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
|
|||||||
args_d.update(
|
args_d.update(
|
||||||
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
|
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
|
||||||
)
|
)
|
||||||
|
main(argparse.Namespace(**args_d))
|
||||||
|
args_d.update({"do_train": False, "do_predict": True})
|
||||||
|
main(argparse.Namespace(**args_d))
|
||||||
|
|
||||||
args = argparse.Namespace(**args_d)
|
def test_t5_run_sum_cli(self):
|
||||||
main(args)
|
args_d: dict = DEFAULT_ARGS.copy()
|
||||||
|
tmp_dir = make_test_data_dir()
|
||||||
|
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||||
|
args_d.update(
|
||||||
|
data_dir=tmp_dir,
|
||||||
|
model_type="t5",
|
||||||
|
model_name_or_path="patrickvonplaten/t5-tiny-random",
|
||||||
|
train_batch_size=2,
|
||||||
|
eval_batch_size=2,
|
||||||
|
n_gpu=0,
|
||||||
|
output_dir=output_dir,
|
||||||
|
do_predict=True,
|
||||||
|
)
|
||||||
|
main(argparse.Namespace(**args_d))
|
||||||
|
# args_d.update({"do_train": False, "do_predict": True})
|
||||||
|
# main(argparse.Namespace(**args_d))
|
||||||
|
|
||||||
def test_bart_summarization_dataset(self):
|
def test_bart_summarization_dataset(self):
|
||||||
tmp_dir = Path(tempfile.gettempdir())
|
tmp_dir = Path(tempfile.gettempdir())
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ wc -l cnn_articles_input_data.txt # should print 11490
|
|||||||
wc -l cnn_articles_reference_summaries.txt # should print 11490
|
wc -l cnn_articles_reference_summaries.txt # should print 11490
|
||||||
```
|
```
|
||||||
|
|
||||||
### Usage
|
### Generating Summaries
|
||||||
|
|
||||||
To create summaries for each article in dataset, run:
|
To create summaries for each article in dataset, run:
|
||||||
```bash
|
```bash
|
||||||
@@ -23,3 +23,7 @@ python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summar
|
|||||||
```
|
```
|
||||||
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
||||||
The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.
|
The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.
|
||||||
|
|
||||||
|
|
||||||
|
### Finetuning
|
||||||
|
Pass model_type=t5 and model `examples/summarization/bart/finetune.py`
|
||||||
|
|||||||
Reference in New Issue
Block a user