[examples] fix summarization do_predict (#3866)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -38,7 +39,7 @@ MODEL_MODES = {
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
def set_seed(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
@@ -47,7 +48,7 @@ def set_seed(args):
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
|
||||
def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
|
||||
"Initialize a model."
|
||||
|
||||
super(BaseTransformer, self).__init__()
|
||||
@@ -192,7 +193,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
|
||||
|
||||
class LoggingCallback(pl.Callback):
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
logger.info("***** Validation results *****")
|
||||
if pl_module.is_logger():
|
||||
metrics = trainer.callback_metrics
|
||||
@@ -201,7 +202,7 @@ class LoggingCallback(pl.Callback):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
logger.info("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
def on_test_end(self, trainer, pl_module):
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
logger.info("***** Test results *****")
|
||||
|
||||
if pl_module.is_logger():
|
||||
@@ -256,7 +257,7 @@ def add_generic_args(parser, root_dir):
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
|
||||
def generic_train(model, args):
|
||||
def generic_train(model: BaseTransformer, args: argparse.Namespace):
|
||||
# init model
|
||||
set_seed(args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user