[s2s] allow task_specific_params=summarization_xsum (#6923)
This commit is contained in:
@@ -75,7 +75,7 @@ class Seq2SeqLoggingCallback(pl.Callback):
|
|||||||
return self._write_logs(trainer, pl_module, "test")
|
return self._write_logs(trainer, pl_module, "test")
|
||||||
|
|
||||||
|
|
||||||
def get_checkpoint_callback(output_dir, metric):
|
def get_checkpoint_callback(output_dir, metric, save_top_k=1):
|
||||||
"""Saves the best model by validation ROUGE2 score."""
|
"""Saves the best model by validation ROUGE2 score."""
|
||||||
if metric == "rouge2":
|
if metric == "rouge2":
|
||||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||||
@@ -90,7 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
|
|||||||
filepath=os.path.join(output_dir, exp),
|
filepath=os.path.join(output_dir, exp),
|
||||||
monitor=f"val_{metric}",
|
monitor=f"val_{metric}",
|
||||||
mode="max",
|
mode="max",
|
||||||
save_top_k=1,
|
save_top_k=save_top_k,
|
||||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
)
|
)
|
||||||
return checkpoint_callback
|
return checkpoint_callback
|
||||||
|
|||||||
@@ -306,6 +306,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||||
parser.add_argument("--eval_beams", type=int, default=None, required=False)
|
parser.add_argument("--eval_beams", type=int, default=None, required=False)
|
||||||
parser.add_argument("--val_metric", type=str, default=None, required=False)
|
parser.add_argument("--val_metric", type=str, default=None, required=False)
|
||||||
|
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stopping_patience",
|
"--early_stopping_patience",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -336,7 +337,7 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
if model is None:
|
if model is None:
|
||||||
if args.task == "summarization":
|
if "summarization" in args.task:
|
||||||
model: SummarizationModule = SummarizationModule(args)
|
model: SummarizationModule = SummarizationModule(args)
|
||||||
else:
|
else:
|
||||||
model: SummarizationModule = TranslationModule(args)
|
model: SummarizationModule = TranslationModule(args)
|
||||||
@@ -368,7 +369,7 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
logging_callback=Seq2SeqLoggingCallback(),
|
logging_callback=Seq2SeqLoggingCallback(),
|
||||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k),
|
||||||
early_stopping_callback=es_callback,
|
early_stopping_callback=es_callback,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
# TODO: early stopping callback seems messed up
|
# TODO: early stopping callback seems messed up
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ CHEAP_ARGS = {
|
|||||||
"label_smoothing": 0.2,
|
"label_smoothing": 0.2,
|
||||||
"eval_beams": 1,
|
"eval_beams": 1,
|
||||||
"val_metric": None,
|
"val_metric": None,
|
||||||
|
"save_top_k": 1,
|
||||||
"adafactor": True,
|
"adafactor": True,
|
||||||
"early_stopping_patience": 2,
|
"early_stopping_patience": 2,
|
||||||
"logger_name": "default",
|
"logger_name": "default",
|
||||||
|
|||||||
Reference in New Issue
Block a user