Introduce Stateful Callbacks (#29666)
* Introduce saveable callbacks
* Add note
* Test for non-present and flag
* Support early stopping and refusing to train further
* Update docstring
* More saving
* Import oopsie
* Apply suggestions from code review
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Make it go through TrainerArguments
* Document
* Fix test
* Apply suggestions from code review
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Rework to allow for duplicates
* Clean
* Fix failing tests

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -19,28 +21,44 @@ from unittest.mock import patch
|
||||
|
||||
from transformers import (
|
||||
DefaultFlowCallback,
|
||||
EarlyStoppingCallback,
|
||||
IntervalStrategy,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.trainer_callback import ExportableState
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.trainer import DEFAULT_CALLBACKS
|
||||
from transformers.trainer import DEFAULT_CALLBACKS, TRAINER_STATE_NAME
|
||||
|
||||
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
||||
|
||||
|
||||
class MyTestExportableCallback(TrainerCallback, ExportableState):
    """Minimal exportable callback used by the stateful-callback tests.

    It holds a single constructor argument, ``my_test_state``, and reports it
    back through :meth:`state`.
    """

    def __init__(self, my_test_state="test"):
        self.my_test_state = my_test_state

    def state(self):
        """Return the exported state: the constructor arguments under ``"args"``."""
        init_args = {"my_test_state": self.my_test_state}
        return {"args": init_args}
|
||||
|
||||
|
||||
class MyTestTrainerCallback(TrainerCallback):
|
||||
"A callback that registers the events that goes through."
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, my_test_state="test"):
|
||||
self.events = []
|
||||
self.my_test_state = my_test_state
|
||||
|
||||
def on_init_end(self, args, state, control, **kwargs):
|
||||
self.events.append("on_init_end")
|
||||
@@ -243,3 +261,160 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
||||
)
|
||||
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
||||
|
||||
def test_stateful_callbacks(self):
    """Train with a non-default `EarlyStoppingCallback`, then resume from the
    checkpoint with a default one and check the non-default values come back."""
    # Settings shared by both trainers: save and evaluate every 2 steps, stop at step 2.
    shared_kwargs = {
        "load_best_model_at_end": True,
        "save_strategy": "steps",
        "eval_strategy": "steps",
        "save_steps": 2,
        "eval_steps": 2,
        "max_steps": 2,
    }

    # First run: the callback is configured with non-default values.
    configured = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2)
    first_trainer = self.get_trainer(callbacks=[configured], **shared_kwargs)
    first_trainer.train()

    # Second run: a fresh trainer whose callback starts from the defaults.
    trainer = self.get_trainer(
        callbacks=[EarlyStoppingCallback()],
        **shared_kwargs,
        restore_callback_states_from_checkpoint=True,
    )

    # Resume from the checkpoint and verify the values held by the trainer's callback.
    checkpoint = os.path.join(self.output_dir, "checkpoint-2")
    trainer.train(resume_from_checkpoint=checkpoint)
    early_stopping_callbacks = [
        candidate
        for candidate in trainer.callback_handler.callbacks
        if isinstance(candidate, EarlyStoppingCallback)
    ]
    restored = early_stopping_callbacks[0]
    assert restored.early_stopping_patience == 5
    assert restored.early_stopping_threshold == 0.2
|
||||
|
||||
def test_stateful_mixed_callbacks(self):
    """Resume training with two callbacks, one stateful and one not.

    The first run uses non-default values for both callbacks. The second run
    starts both from their defaults and resumes from the checkpoint. The test
    then expects the `EarlyStoppingCallback` to carry the first run's values
    and the `MyTestTrainerCallback` to still hold its default.
    """
    # Settings shared by both trainers: save and evaluate every 2 steps, stop at step 2.
    shared_kwargs = {
        "load_best_model_at_end": True,
        "save_strategy": "steps",
        "eval_strategy": "steps",
        "save_steps": 2,
        "eval_steps": 2,
        "max_steps": 2,
    }

    # Use something with non-defaults
    cbs = [
        MyTestTrainerCallback(my_test_state="another value"),
        EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2),
    ]
    trainer = self.get_trainer(callbacks=cbs, **shared_kwargs)
    trainer.train()

    # Create a new trainer with defaults
    trainer = self.get_trainer(
        callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()],
        **shared_kwargs,
        restore_callback_states_from_checkpoint=True,
    )

    # Load it back in and verify values
    checkpoint = os.path.join(self.output_dir, "checkpoint-2")
    trainer.train(resume_from_checkpoint=checkpoint)
    cbs = [
        callback
        for callback in trainer.callback_handler.callbacks
        if isinstance(callback, (EarlyStoppingCallback, MyTestTrainerCallback))
    ]
    assert len(cbs) == 2

    # Select each callback by type instead of unpacking positionally, so the
    # checks do not depend on the order in which the handler lists them.
    early_stopping = next(cb for cb in cbs if isinstance(cb, EarlyStoppingCallback))
    my_test = next(cb for cb in cbs if isinstance(cb, MyTestTrainerCallback))

    assert early_stopping.early_stopping_patience == 5
    assert early_stopping.early_stopping_threshold == 0.2
    assert my_test.my_test_state == "test"
|
||||
|
||||
def test_stateful_duplicate_callbacks(self):
    """Resume training with two callbacks of the same exportable type and
    check that each one gets its own saved value back, in order."""
    # Settings shared by both trainers: save and evaluate every 2 steps, stop at step 2.
    shared_kwargs = {
        "load_best_model_at_end": True,
        "save_strategy": "steps",
        "eval_strategy": "steps",
        "save_steps": 2,
        "eval_steps": 2,
        "max_steps": 2,
    }

    # First run: two instances of the same class with distinct, non-default values.
    first_run_callbacks = [MyTestExportableCallback("first"), MyTestExportableCallback("second")]
    first_trainer = self.get_trainer(callbacks=first_run_callbacks, **shared_kwargs)
    first_trainer.train()

    # Second run: both instances start from the default value.
    trainer = self.get_trainer(
        callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
        **shared_kwargs,
        restore_callback_states_from_checkpoint=True,
    )

    # Resume from the checkpoint and verify the values on the trainer's callbacks.
    checkpoint = os.path.join(self.output_dir, "checkpoint-2")
    trainer.train(resume_from_checkpoint=checkpoint)
    restored = [
        candidate
        for candidate in trainer.callback_handler.callbacks
        if isinstance(candidate, MyTestExportableCallback)
    ]
    assert len(restored) == 2
    first, second = restored[0], restored[1]
    assert first.my_test_state == "first"
    assert second.my_test_state == "second"
|
||||
|
||||
def test_missing_stateful_callback(self):
    """Resume from a checkpoint written with an `EarlyStoppingCallback` using a
    trainer that was not given one, and check that the logged warning names it."""
    # Step settings shared by both trainers.
    step_kwargs = {
        "save_strategy": "steps",
        "eval_strategy": "steps",
        "save_steps": 2,
        "eval_steps": 2,
        "max_steps": 2,
    }

    # First run: train with the callback.
    first_trainer = self.get_trainer(
        callbacks=[EarlyStoppingCallback()],
        load_best_model_at_end=True,
        **step_kwargs,
    )
    first_trainer.train()

    # Second run: no callbacks are passed, but callback restoration is enabled.
    trainer = self.get_trainer(**step_kwargs, restore_callback_states_from_checkpoint=True)
    checkpoint = os.path.join(self.output_dir, "checkpoint-2")

    # A warning should be emitted for callbacks that are not present.
    with patch("transformers.trainer.logger.warning") as warn_mock:
        trainer.train(resume_from_checkpoint=checkpoint)
        warning_message = warn_mock.call_args[0][0]
        assert "EarlyStoppingCallback" in warning_message
|
||||
|
||||
def test_stateful_control(self):
    """Load the trainer state saved at step 2 into a fresh trainer, call
    `_load_callback_state`, and check `control.should_training_stop` is set."""
    # First run: train for two steps and save a checkpoint at step 2.
    first_trainer = self.get_trainer(max_steps=2, save_strategy="steps", save_steps=2)
    first_trainer.train()

    # Fresh trainer with callback-state restoration enabled.
    trainer = self.get_trainer(max_steps=2, restore_callback_states_from_checkpoint=True)

    # Load the saved trainer state directly instead of going through `train`.
    checkpoint = os.path.join(self.output_dir, "checkpoint-2")
    state_file = os.path.join(checkpoint, TRAINER_STATE_NAME)
    trainer.state = TrainerState.load_from_json(state_file)
    trainer._load_callback_state()

    assert trainer.control.should_training_stop
|
||||
|
||||
Reference in New Issue
Block a user