Lightning Updates for v0.8.5 (#5798)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Nathan Raw
2020-07-17 20:43:06 -06:00
committed by GitHub
parent 615be03f9d
commit 529850ae7b
7 changed files with 73 additions and 97 deletions

View File

@@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer):
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.gradient_accumulation_steps
* float(self.hparams.num_train_epochs)
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
@@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer):
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer):
)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
return parser
@@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
dataset = Path(args.data_dir).name
if (
args.logger == "default"
args.logger_name == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger == "wandb":
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name)
logger = WandbLogger(name=model.output_dir.name, project=dataset)
elif args.logger == "wandb_shared":
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name)
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
trainer: pl.Trainer = generic_train(
model,
args,
@@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]
trainer.logger.log_hyperparams(model.hparams)
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
# test() without a model tests using the best checkpoint automatically
trainer.test()
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)