[examples] summarization/bart/finetune.py supports t5 (#3824)

Renames `run_bart_sum.py` to `finetune.py`.
This commit is contained in:
Sam Shleifer
2020-04-16 15:15:19 -04:00
committed by GitHub
parent 0cec4fab7d
commit f0c96fafd1
5 changed files with 36 additions and 14 deletions

View File

@@ -19,7 +19,7 @@ except ImportError:
logger = logging.getLogger(__name__)
class BartSystem(BaseTransformer):
class SummarizationTrainer(BaseTransformer):
mode = "language-modeling"
@@ -64,18 +64,18 @@ class BartSystem(BaseTransformer):
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
def test_step(self, batch, batch_idx):
# NOTE: this generation will not use the cache.
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
# NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
# NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
generated_ids = self.model.generate(
source_ids,
source_mask,
input_ids=source_ids,
attention_mask=source_mask,
num_beams=1,
max_length=80,
repetition_penalty=2.5,
length_penalty=1.0,
early_stopping=True,
use_cache=True,
)
preds = [
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -161,20 +161,20 @@ def main(args):
if not args.output_dir:
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
os.makedirs(args.output_dir)
model = BartSystem(args)
model = SummarizationTrainer(args)
trainer = generic_train(model, args)
# Optionally, predict on dev set and write to output_dir
if args.do_predict:
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
BartSystem.load_from_checkpoint(checkpoints[-1])
SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
trainer.test(model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
parser = BartSystem.add_model_specific_args(parser, os.getcwd())
parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)

View File

@@ -8,7 +8,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"
python run_bart_sum.py \
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=bart \
--model_name_or_path=bart-large \

View File

@@ -14,7 +14,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py and utils.py
export PYTHONPATH="../../":"${PYTHONPATH}"
python run_bart_sum.py \
python finetune.py \
--data_dir=cnn_tiny/ \
--model_type=bart \
--model_name_or_path=sshleifer/bart-tiny-random \

View File

@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
from transformers import BartTokenizer
from .evaluate_cnn import run_generate
from .run_bart_sum import main
from .finetune import main
from .utils import SummarizationDataset
@@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
args_d.update(
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
)
main(argparse.Namespace(**args_d))
args_d.update({"do_train": False, "do_predict": True})
main(argparse.Namespace(**args_d))
args = argparse.Namespace(**args_d)
main(args)
def test_t5_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
train_batch_size=2,
eval_batch_size=2,
n_gpu=0,
output_dir=output_dir,
do_predict=True,
)
main(argparse.Namespace(**args_d))
# args_d.update({"do_train": False, "do_predict": True})
# main(argparse.Namespace(**args_d))
def test_bart_summarization_dataset(self):
tmp_dir = Path(tempfile.gettempdir())