From f5c2a122e34836b87abb6042cf641b040e790e1c Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 22 Jun 2020 20:40:10 -0400 Subject: [PATCH] Upgrade examples to pl=0.8.1(#5146) --- examples/lightning_base.py | 63 +++++-------- examples/requirements.txt | 2 +- examples/summarization/callbacks.py | 3 +- examples/summarization/distillation.py | 1 + examples/summarization/finetune.py | 29 ++---- examples/summarization/run_distiller.sh | 1 - examples/summarization/run_eval.py | 1 + .../test_summarization_examples.py | 89 ++++--------------- examples/summarization/utils.py | 2 + examples/text-classification/run_pl_glue.py | 10 +-- src/transformers/tokenization_auto.py | 2 +- 11 files changed, 53 insertions(+), 150 deletions(-) diff --git a/examples/lightning_base.py b/examples/lightning_base.py index 39604efae3..2574aa9458 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -8,6 +8,7 @@ from typing import Any, Dict import numpy as np import pytorch_lightning as pl import torch +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from transformers import ( AdamW, @@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule): model=None, **config_kwargs ): - "Initialize a model." - + """Initialize a model, tokenizer and config.""" super().__init__() - self.hparams = hparams + self.hparams = hparams # TODO: move to self.save_hyperparameters() self.step_count = 0 self.tfmr_ckpts = {} self.output_dir = Path(self.hparams.output_dir) @@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule): ) else: self.tokenizer: PreTrainedTokenizer = tokenizer + self.model_type = MODEL_MODES[mode] if model is None: - self.model_type = MODEL_MODES[mode] self.model = self.model_type.from_pretrained( self.hparams.model_name_or_path, from_tf=bool(".ckpt" in self.hparams.model_name_or_path), @@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule): cache_dir=cache_dir, ) else: - self.model_type = None self.model = model def load_hf_checkpoint(self, *args, **kwargs): self.model = self.model_type.from_pretrained(*args, **kwargs) - def is_logger(self): - return self.trainer.proc_rank <= 0 - def configure_optimizers(self): "Prepare optimizer and schedule (linear warmup and decay)" - model = self.model no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ @@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule): self.opt = optimizer return [optimizer] - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): - if self.trainer.use_tpu: - xm.optimizer_step(optimizer) - else: - optimizer.step() - optimizer.zero_grad() - self.lr_scheduler.step() - - def get_tqdm_dict(self): - avg_loss = getattr(self.trainer, "avg_loss", 0.0) - tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]} - return tqdm_dict - def test_step(self, batch, batch_nb): return self.validation_step(batch, batch_nb) - def test_end(self, outputs): + def test_epoch_end(self, outputs): return self.validation_end(outputs) def train_dataloader(self): @@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule): parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.") + parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") parser.add_argument( "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform." ) @@ -217,28 +200,26 @@ class BaseTransformer(pl.LightningModule): class LoggingCallback(pl.Callback): + @rank_zero_only def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - logger.info("***** Validation results *****") - if pl_module.is_logger(): - metrics = trainer.callback_metrics - # Log results + rank_zero_info("***** Validation results *****") + metrics = trainer.callback_metrics + # Log results + for key in sorted(metrics): + if key not in ["log", "progress_bar"]: + rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) + + @rank_zero_only + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + logger.info("***** Test results *****") + metrics = trainer.callback_metrics + # Log and save results to file + output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") + with open(output_test_results_file, "w") as writer: for key in sorted(metrics): if key not in ["log", "progress_bar"]: logger.info("{} = {}\n".format(key, str(metrics[key]))) - - def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - logger.info("***** Test results *****") - - if pl_module.is_logger(): - metrics = trainer.callback_metrics - - # Log and save results to file - output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") - with open(output_test_results_file, "w") as writer: - for key in sorted(metrics): - if key not in ["log", "progress_bar"]: - logger.info("{} = {}\n".format(key, str(metrics[key]))) - writer.write("{} = {}\n".format(key, str(metrics[key]))) + writer.write("{} = {}\n".format(key, str(metrics[key]))) def add_generic_args(parser, root_dir) -> None: diff --git a/examples/requirements.txt b/examples/requirements.txt index daf2081fe9..6ab5c2c05a 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -5,7 +5,7 @@ psutil sacrebleu rouge-score tensorflow_datasets -pytorch-lightning==0.7.6 +pytorch-lightning==0.8.1 matplotlib git-python==1.0.3 faiss diff --git a/examples/summarization/callbacks.py b/examples/summarization/callbacks.py index 6129d5f0b9..83b54d08c7 100644 --- a/examples/summarization/callbacks.py +++ b/examples/summarization/callbacks.py @@ -19,12 +19,11 @@ logger = logging.getLogger(__name__) class Seq2SeqLoggingCallback(pl.Callback): + @rank_zero_only def _write_logs( self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True ) -> None: logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") - if not pl_module.is_logger(): - return metrics = trainer.callback_metrics trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) # Log results diff --git a/examples/summarization/distillation.py b/examples/summarization/distillation.py index c9f5d5b04e..290dde0518 100644 --- a/examples/summarization/distillation.py +++ b/examples/summarization/distillation.py @@ -271,6 +271,7 @@ class SummarizationDistiller(SummarizationModule): class T5SummarizationDistiller(SummarizationDistiller): def pre_init(self, hparams): + raise NotImplementedError("T5 Distillation does not work yet") teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher) n_layer = hparams.student_decoder_layers assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this diff --git a/examples/summarization/finetune.py b/examples/summarization/finetune.py index 56a1984635..f2e3f64637 100644 --- a/examples/summarization/finetune.py +++ b/examples/summarization/finetune.py @@ -85,7 +85,7 @@ class SummarizationModule(BaseTransformer): if self.hparams.freeze_encoder: freeze_params(self.model.model.encoder) # TODO: this will break for t5 self.hparams.git_sha = get_git_info()["repo_sha"] - self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu + self.num_workers = hparams.num_workers def freeze_embeds(self): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" @@ -126,7 +126,7 @@ class SummarizationModule(BaseTransformer): def validation_step(self, batch, batch_idx) -> Dict: return self._generative_step(batch) - def validation_end(self, outputs, prefix="val") -> Dict: + def validation_epoch_end(self, outputs, prefix="val") -> Dict: self.step_count += 1 losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} loss = losses["loss"] @@ -144,14 +144,12 @@ class SummarizationModule(BaseTransformer): self.metrics[prefix].append(metrics) pickle_save(self.metrics, self.metrics_save_path) - def _generative_step(self, batch): + def _generative_step(self, batch: dict) -> dict: pad_token_id = self.tokenizer.pad_token_id source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id) - # TODO(SS): task specific params - t0 = time.time() generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,) - gen_time = time.time() - t0 + gen_time = time.time() - t0 / source_ids.shape[0] preds = self.ids_to_clean_text(generated_ids) target = self.ids_to_clean_text(y) loss_tensors = self._step(batch) @@ -164,24 +162,8 @@ class SummarizationModule(BaseTransformer): def test_step(self, batch, batch_idx): return self._generative_step(batch) - def test_end(self, outputs): - return self.validation_end(outputs, prefix="test") - def test_epoch_end(self, outputs): - output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt") - output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt") - # write predictions and targets for later rouge evaluation. - with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer: - for output_batch in outputs: - p_writer.writelines(s + "\n" for s in output_batch["preds"]) - t_writer.writelines(s + "\n" for s in output_batch["target"]) - p_writer.close() - t_writer.close() - - return self.test_end(outputs) - - def validation_epoch_end(self, outputs): - self.validation_end(outputs, "val") + return self.validation_epoch_end(outputs, prefix="test") def get_dataset(self, type_path) -> SummarizationDataset: n_obs = self.n_obs[type_path] @@ -310,6 +292,7 @@ def main(args, model=None) -> SummarizationModule: logger=logger, # TODO: early stopping callback seems messed up ) + pickle_save(model.hparams, model.output_dir / "hparams.pkl") if not args.do_predict: return model diff --git a/examples/summarization/run_distiller.sh b/examples/summarization/run_distiller.sh index 6fbecad388..a4d43de64a 100755 --- a/examples/summarization/run_distiller.sh +++ b/examples/summarization/run_distiller.sh @@ -7,6 +7,5 @@ python distillation.py \ --learning_rate=3e-4 \ --do_train \ --do_predict \ ---fp16 \ --val_check_interval 0.1 \ $@ diff --git a/examples/summarization/run_eval.py b/examples/summarization/run_eval.py index 3013743c89..0bbaf9d64a 100644 --- a/examples/summarization/run_eval.py +++ b/examples/summarization/run_eval.py @@ -26,6 +26,7 @@ def generate_summaries( examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False, ) -> None: fout = Path(out_file).open("w", encoding="utf-8") + model_name = str(model_name) model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) if fp16: model = model.half() diff --git a/examples/summarization/test_summarization_examples.py b/examples/summarization/test_summarization_examples.py index 9688d7f88a..d829793ce1 100644 --- a/examples/summarization/test_summarization_examples.py +++ b/examples/summarization/test_summarization_examples.py @@ -24,6 +24,7 @@ logger = logging.getLogger() FP16_EVER = False CHEAP_ARGS = { "logger": "default", + "num_workers": 2, "alpha_hid": 0, "freeze_embeds": True, "enc_only": False, @@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list): f.write("\n".join(articles)) -BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute() +MSG = "T5 is broken at the moment" +T5_TINY = "patrickvonplaten/t5-tiny-random" def make_test_data_dir(): @@ -92,7 +94,6 @@ def make_test_data_dir(): return tmp_dir -@unittest.skip("These wont' pass until hidden_states kwarg is merged.") class TestSummarizationDistiller(unittest.TestCase): @classmethod def setUpClass(cls): @@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase): freeze_encoder=True, gpus=2, sortish_sampler=False, - ) - self._bart_distiller_cli(updates) - - @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test") - def test_bdc_fp16(self): - updates = dict( - student_encoder_layers=2, - student_decoder_layers=1, - alpha_hid=3.0, - freeze_encoder=True, - gpus=1, - fp16=FP16_EVER, fp16_opt_level="O1", + fp16=FP16_EVER, ) self._bart_distiller_cli(updates) - @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test") - def test_bdc_t5_eval_fp16(self): + def test_bdc_t5_train(self): updates = dict( fp16=FP16_EVER, - gpus=1, + gpus=1 if torch.cuda.is_available() else 0, model_type="t5", - model_name_or_path="patrickvonplaten/t5-tiny-random", - do_train=False, - do_predict=True, - tokenizer_name=None, - no_teacher=True, - ) - self._bart_distiller_cli(updates, check_contents=False) - - @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test") - def test_bdc_t5_train_fp16(self): - updates = dict( - fp16=FP16_EVER, - gpus=1, - model_type="t5", - model_name_or_path="patrickvonplaten/t5-tiny-random", + model_name_or_path=T5_TINY, do_train=True, do_predict=True, - tokenizer_name="patrickvonplaten/t5-tiny-random", + tokenizer_name=T5_TINY, no_teacher=True, + alpha_hid=2.0, ) self._bart_distiller_cli(updates) @@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase): self._bart_distiller_cli(updates) def test_bdc_checkpointing(self): - updates = dict( student_encoder_layers=2, student_decoder_layers=1, @@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase): evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) - def test_bdc_t5(self): - updates = dict( - student_encoder_layers=1, - student_decoder_layers=1, - alpha_hid=2.0, - teacher="patrickvonplaten/t5-tiny-random", - model_type="t5", - model_name_or_path="patrickvonplaten/t5-tiny-random", - tokenizer_name="patrickvonplaten/t5-tiny-random", - ) - self._bart_distiller_cli(updates) - - def test_bdc_t5_eval(self): - updates = dict( - model_type="t5", - model_name_or_path="patrickvonplaten/t5-tiny-random", - do_train=False, - do_predict=True, - tokenizer_name="patrickvonplaten/t5-tiny-random", - no_teacher=True, - ) - self._bart_distiller_cli(updates, check_contents=False) - def _bart_distiller_cli(self, updates, check_contents=True): default_updates = dict( - model_type="bart", train_batch_size=1, eval_batch_size=2, num_train_epochs=2, @@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase): self.assertIn(ckpt_name, contents) self.assertIn("metrics.pkl", contents) self.assertIn("test_generations.txt", contents) - self.assertIn("val_generations_1.txt", contents) - self.assertIn("val_1_results.txt", contents) + self.assertIn("val_generations_00001.txt", contents) + self.assertIn("val_results_00001.txt", contents) self.assertIn("test_results.txt", contents) - # self.assertEqual(len(contents), 15) metrics = pickle_load(Path(output_dir) / "metrics.pkl") - import pandas as pd - - val_df = pd.DataFrame(metrics["val"]) - train_df = pd.DataFrame(metrics["train"]) - test_df = pd.DataFrame(metrics["test"]) - desired_n_evals = args_d["num_train_epochs"] * 2 + 1 - self.assertEqual(val_df.shape[0], desired_n_evals) # - self.assertEqual(test_df.shape[1], val_df.shape[1]) - self.assertEqual(train_df.shape[0], 0) + desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1) + self.assertEqual(len(metrics["val"]), desired_n_evals) + self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here return model @@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase): output_dir = tempfile.mkdtemp(prefix="output_") args_d.update( data_dir=tmp_dir, - model_type="t5", - model_name_or_path="patrickvonplaten/t5-tiny-random", - tokenizer_name=None, # "patrickvonplaten/t5-tiny-random", + model_name_or_path=T5_TINY, + tokenizer_name=None, # T5_TINY, train_batch_size=2, eval_batch_size=2, gpus=0, diff --git a/examples/summarization/utils.py b/examples/summarization/utils.py index d59e733ab0..bff6c3de3e 100644 --- a/examples/summarization/utils.py +++ b/examples/summarization/utils.py @@ -45,8 +45,10 @@ def encode_file( max_length=max_length, pad_to_max_length=pad_to_max_length, add_prefix_space=True, + truncation=True, return_tensors=return_tensors, ) + assert tokenized.input_ids.shape[1] == max_length examples.append(tokenized) torch.save(lmap(dict, examples), cache_path.open("wb")) return examples diff --git a/examples/text-classification/run_pl_glue.py b/examples/text-classification/run_pl_glue.py index 88e5912cad..19d8c913dd 100644 --- a/examples/text-classification/run_pl_glue.py +++ b/examples/text-classification/run_pl_glue.py @@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer): return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids} - def _eval_end(self, outputs): + def _eval_end(self, outputs) -> tuple: val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item() preds = np.concatenate([x["pred"] for x in outputs], axis=0) @@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer): logs = ret["log"] return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs} - def test_epoch_end(self, outputs): - # updating to test_epoch_end instead of deprecated test_end + def test_epoch_end(self, outputs) -> dict: ret, predictions, targets = self._eval_end(outputs) - - # Converting to the dic required by pl - # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\ - # pytorch_lightning/trainer/logging.py#L139 logs = ret["log"] # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss` return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs} @staticmethod def add_model_specific_args(parser, root_dir): - # Add NER specific options BaseTransformer.add_model_specific_args(parser, root_dir) parser.add_argument( "--max_seq_length", diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 80152e85a0..7e2992c782 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -205,7 +205,7 @@ class AutoTokenizer: if not isinstance(config, PretrainedConfig): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - if "bert-base-japanese" in pretrained_model_name_or_path: + if "bert-base-japanese" in str(pretrained_model_name_or_path): return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) use_fast = kwargs.pop("use_fast", False)