[s2s] support early stopping based on loss, rather than rouge (#6927)
This commit is contained in:
@@ -33,7 +33,7 @@ CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"label_smoothing": 0.2,
|
||||
"eval_beams": 1,
|
||||
"val_metric": None,
|
||||
"val_metric": "loss",
|
||||
"save_top_k": 1,
|
||||
"adafactor": True,
|
||||
"early_stopping_patience": 2,
|
||||
@@ -262,9 +262,9 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
if not check_contents:
|
||||
return model
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
self.assertIn(ckpt_name, contents)
|
||||
ckpt_files = [p for p in contents if p.endswith("ckpt")]
|
||||
assert len(ckpt_files) > 0
|
||||
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
|
||||
Reference in New Issue
Block a user