Lightning Updates for v0.8.5 (#5798)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer):
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.gradient_accumulation_steps
|
||||
* float(self.hparams.num_train_epochs)
|
||||
// self.hparams.accumulate_grad_batches
|
||||
* float(self.hparams.max_epochs)
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
@@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
@@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer):
|
||||
)
|
||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger == "default"
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger == "wandb":
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||
|
||||
elif args.logger == "wandb_shared":
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
@@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1]
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user