Update doc for metric_for_best_model when save_strategy="best". (#35389)
* Updated docstring for _determine_best_metric. * Updated docstring for metric_for_best_model. * Added test case for save strategy. * Updated incorrect test case. * Changed eval_strategy to match save_strategy. * Separated test cases for metric. * Allow load_best_model when save_strategy == "best". * Updated docstring for metric_for_best_model.
This commit is contained in:
committed by
GitHub
parent
29e74b7cbc
commit
88e18b3c63
@@ -4220,7 +4220,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
total=total,
|
||||
)
|
||||
|
||||
# Case 3: Metric name not provided; throw error.
|
||||
def test_metric_for_best_model_behavior(self):
|
||||
# Case 1: Metric name not provided when `save_strategy == "best"`.
|
||||
# Should raise ValueError.
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with self.assertRaises(ValueError) as context:
|
||||
trainer = get_regression_trainer(
|
||||
@@ -4232,9 +4234,22 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
save_strategy="best",
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
)
|
||||
|
||||
self.assertIn("`args.metric_for_best_model` must be provided", str(context.exception))
|
||||
|
||||
# Case 2: Metric name not provided when `load_best_model_at_end == True`.
|
||||
# `metric_for_best_model` should be set to `"loss"` by default.
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_strategy="steps",
|
||||
save_strategy="steps",
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
self.assertTrue(trainer.args.metric_for_best_model == "loss")
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user