diff --git a/examples/requirements.txt b/examples/requirements.txt index 70f1b9999a..33eb2ace3c 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -5,4 +5,5 @@ seqeval psutil sacrebleu rouge-score -tensorflow_datasets \ No newline at end of file +tensorflow_datasets +pytorch-lightning==0.7.3 # April 10, 2020 release diff --git a/examples/summarization/bart/run_bart_sum.py b/examples/summarization/bart/run_bart_sum.py index 3ed6feaf43..5d7876dd5f 100644 --- a/examples/summarization/bart/run_bart_sum.py +++ b/examples/summarization/bart/run_bart_sum.py @@ -8,7 +8,12 @@ import torch from torch.utils.data import DataLoader from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup -from utils import SummarizationDataset + + +try: + from .utils import SummarizationDataset +except ImportError: + from utils import SummarizationDataset logger = logging.getLogger(__name__) @@ -20,6 +25,11 @@ class BartSystem(BaseTransformer): def __init__(self, hparams): super().__init__(hparams, num_labels=None, mode=self.mode) + self.dataset_kwargs: dict = dict( + data_dir=self.hparams.data_dir, + max_source_length=self.hparams.max_source_length, + max_target_length=self.hparams.max_target_length, + ) def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None): return self.model( @@ -92,14 +102,6 @@ class BartSystem(BaseTransformer): return self.test_end(outputs) - @property - def dataset_kwargs(self): - return dict( - data_dir=self.hparams.data_dir, - max_source_length=self.hparams.max_source_length, - max_target_length=self.hparams.max_target_length, - ) - def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader: dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs) dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn) @@ -153,17 +155,12 @@ class BartSystem(BaseTransformer): return parser -if __name__ == "__main__": - parser = argparse.ArgumentParser() - add_generic_args(parser, os.getcwd()) - parser = BartSystem.add_model_specific_args(parser, os.getcwd()) - args = parser.parse_args() +def main(args): # If output_dir not provided, a folder will be generated in pwd if not args.output_dir: args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",) os.makedirs(args.output_dir) - model = BartSystem(args) trainer = generic_train(model, args) @@ -172,3 +169,12 @@ if __name__ == "__main__": checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True))) BartSystem.load_from_checkpoint(checkpoints[-1]) trainer.test(model) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_generic_args(parser, os.getcwd()) + parser = BartSystem.add_model_specific_args(parser, os.getcwd()) + args = parser.parse_args() + + main(args) diff --git a/examples/summarization/bart/run_train.sh b/examples/summarization/bart/run_train.sh index dfdcb14833..d9bd627633 100755 --- a/examples/summarization/bart/run_train.sh +++ b/examples/summarization/bart/run_train.sh @@ -1,7 +1,3 @@ -# Install newest ptl. -pip install -U git+http://github.com/PyTorchLightning/pytorch-lightning/ - - export OUTPUT_DIR_NAME=bart_sum export CURRENT_DIR=${PWD} export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME} @@ -20,4 +16,4 @@ python run_bart_sum.py \ --train_batch_size=4 \ --eval_batch_size=4 \ --output_dir=$OUTPUT_DIR \ ---do_train \ No newline at end of file +--do_train $@ diff --git a/examples/summarization/bart/run_train_tiny.sh b/examples/summarization/bart/run_train_tiny.sh new file mode 100755 index 0000000000..841882096c --- /dev/null +++ b/examples/summarization/bart/run_train_tiny.sh @@ -0,0 +1,33 @@ +# Script for verifying that run_bart_sum can be invoked from its directory + +# Get tiny dataset with cnn_dm format (4 examples for train, val, test) +wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_tiny.tgz +tar -xzvf cnn_tiny.tgz +rm cnn_tiny.tgz + +export OUTPUT_DIR_NAME=bart_utest_output +export CURRENT_DIR=${PWD} +export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME} + +# Make output directory if it doesn't exist +mkdir -p $OUTPUT_DIR + +# Add parent directory to python path to access transformer_base.py and utils.py +export PYTHONPATH="../../":"${PYTHONPATH}" +python run_bart_sum.py \ +--data_dir=cnn_tiny/ \ +--model_type=bart \ +--model_name_or_path=sshleifer/bart-tiny-random \ +--learning_rate=3e-5 \ +--train_batch_size=2 \ +--eval_batch_size=2 \ +--output_dir=$OUTPUT_DIR \ +--num_train_epochs=1 \ +--n_gpu=0 \ +--do_train $@ + +rm -rf cnn_tiny +rm -rf $OUTPUT_DIR + + + diff --git a/examples/summarization/bart/test_bart_examples.py b/examples/summarization/bart/test_bart_examples.py index b136edfdb7..8d2e7bf93e 100644 --- a/examples/summarization/bart/test_bart_examples.py +++ b/examples/summarization/bart/test_bart_examples.py @@ -1,4 +1,6 @@ +import argparse import logging +import os import sys import tempfile import unittest @@ -10,6 +12,7 @@ from torch.utils.data import DataLoader from transformers import BartTokenizer from .evaluate_cnn import run_generate +from .run_bart_sum import main from .utils import SummarizationDataset @@ -17,16 +20,61 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() +DEFAULT_ARGS = { + "output_dir": "", + "fp16": False, + "fp16_opt_level": "O1", + "n_gpu": 1, + "n_tpu_cores": 0, + "max_grad_norm": 1.0, + "do_train": True, + "do_predict": False, + "gradient_accumulation_steps": 1, + "server_ip": "", + "server_port": "", + "seed": 42, + "model_type": "bart", + "model_name_or_path": "sshleifer/bart-tiny-random", + "config_name": "", + "tokenizer_name": "", + "cache_dir": "", + "do_lower_case": False, + "learning_rate": 3e-05, + "weight_decay": 0.0, + "adam_epsilon": 1e-08, + "warmup_steps": 0, + "num_train_epochs": 1, + "train_batch_size": 2, + "eval_batch_size": 2, + "max_source_length": 12, + "max_target_length": 12, +} + def _dump_articles(path: Path, articles: list): with path.open("w") as f: f.write("\n".join(articles)) +def make_test_data_dir(): + tmp_dir = Path(tempfile.gettempdir()) + articles = [" Sam ate lunch today", "Sams lunch ingredients"] + summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] + for split in ["train", "val", "test"]: + _dump_articles((tmp_dir / f"{split}.source"), articles) + _dump_articles((tmp_dir / f"{split}.target"), summaries) + return tmp_dir + + class TestBartExamples(unittest.TestCase): - def test_bart_cnn_cli(self): + @classmethod + def setUpClass(cls): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) + logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks + return cls + + def test_bart_cnn_cli(self): tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo" output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo" articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] @@ -34,7 +82,19 @@ class TestBartExamples(unittest.TestCase): testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"] with patch.object(sys, "argv", testargs): run_generate() - self.assertTrue(output_file_name.exists()) + self.assertTrue(Path(output_file_name).exists()) + os.remove(Path(output_file_name)) + + def test_bart_run_sum_cli(self): + args_d: dict = DEFAULT_ARGS.copy() + tmp_dir = make_test_data_dir() + output_dir = tempfile.mkdtemp(prefix="output_") + args_d.update( + data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir, + ) + + args = argparse.Namespace(**args_d) + main(args) def test_bart_summarization_dataset(self): tmp_dir = Path(tempfile.gettempdir()) diff --git a/examples/transformer_base.py b/examples/transformer_base.py index a744c91594..a3b81610ea 100644 --- a/examples/transformer_base.py +++ b/examples/transformer_base.py @@ -104,8 +104,8 @@ class BaseTransformer(pl.LightningModule): self.lr_scheduler.step() def get_tqdm_dict(self): - tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]} - + avg_loss = getattr(self.trainer, "avg_loss", 0.0) + tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]} return tqdm_dict def test_step(self, batch, batch_nb):