[examples/seq2seq]: add --label_smoothing option (#5919)

This commit is contained in:
Sam Shleifer
2020-07-21 16:51:39 -04:00
committed by GitHub
parent 95d1962b9c
commit 5b193b39b0
7 changed files with 132 additions and 46 deletions

View File

@@ -5,7 +5,7 @@ from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
@@ -90,3 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback
def get_early_stopping_callback(metric, patience):
return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,)