[examples/seq2seq]: add --label_smoothing option (#5919)
This commit is contained in:
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
|
||||
@@ -90,3 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
|
||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
return checkpoint_callback
|
||||
|
||||
|
||||
def get_early_stopping_callback(metric, patience):
|
||||
return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,)
|
||||
|
||||
Reference in New Issue
Block a user