Upgrade PyTorch Lightning to 1.0.2 (#7852)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
|
||||
def generic_train(
|
||||
model: BaseTransformer,
|
||||
args: argparse.Namespace,
|
||||
early_stopping_callback=False,
|
||||
early_stopping_callback=None,
|
||||
logger=True, # can pass WandbLogger() here
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
@@ -355,6 +355,8 @@ def generic_train(
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||
)
|
||||
if early_stopping_callback:
|
||||
extra_callbacks.append(early_stopping_callback)
|
||||
if logging_callback is None:
|
||||
logging_callback = LoggingCallback()
|
||||
|
||||
@@ -376,7 +378,6 @@ def generic_train(
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
logger=logger,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
early_stop_callback=early_stopping_callback,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user