Upgrade PyTorch Lightning to 1.0.2 (#7852)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Sean Naren
2020-10-28 18:59:14 +00:00
committed by GitHub
parent 1b6c8d4811
commit 5e24982e58
8 changed files with 11 additions and 13 deletions

View File

@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
def generic_train(
model: BaseTransformer,
args: argparse.Namespace,
early_stopping_callback=False,
early_stopping_callback=None,
logger=True, # can pass WandbLogger() here
extra_callbacks=[],
checkpoint_callback=None,
@@ -355,6 +355,8 @@ def generic_train(
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
)
if early_stopping_callback:
extra_callbacks.append(early_stopping_callback)
if logging_callback is None:
logging_callback = LoggingCallback()
@@ -376,7 +378,6 @@ def generic_train(
callbacks=[logging_callback] + extra_callbacks,
logger=logger,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stopping_callback,
**train_params,
)