[s2s] allow task_specific_params=summarization_xsum (#6923)
This commit is contained in:
@@ -306,6 +306,7 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--eval_beams", type=int, default=None, required=False)
|
||||
parser.add_argument("--val_metric", type=str, default=None, required=False)
|
||||
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
@@ -336,7 +337,7 @@ def main(args, model=None) -> SummarizationModule:
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
if model is None:
|
||||
if args.task == "summarization":
|
||||
if "summarization" in args.task:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
@@ -368,7 +369,7 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
# TODO: early stopping callback seems messed up
|
||||
|
||||
Reference in New Issue
Block a user