Introduce Stateful Callbacks (#29666)
* Introduce saveable callbacks * Add note * Test for non-present and flag * Support early stopping and refusing to train further * Update docstring * More saving * Import oopsie * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Make it go through TrainerArguments * Document * Fix test * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Rework to allow for duplicates * CLean * Fix failing tests --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -78,6 +78,7 @@ from .tokenization_utils_base import PreTrainedTokenizerBase
|
|||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
CallbackHandler,
|
CallbackHandler,
|
||||||
DefaultFlowCallback,
|
DefaultFlowCallback,
|
||||||
|
ExportableState,
|
||||||
PrinterCallback,
|
PrinterCallback,
|
||||||
ProgressCallback,
|
ProgressCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
@@ -649,12 +650,15 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
self.label_smoother = None
|
self.label_smoother = None
|
||||||
|
|
||||||
|
self.control = TrainerControl()
|
||||||
|
|
||||||
self.state = TrainerState(
|
self.state = TrainerState(
|
||||||
is_local_process_zero=self.is_local_process_zero(),
|
is_local_process_zero=self.is_local_process_zero(),
|
||||||
is_world_process_zero=self.is_world_process_zero(),
|
is_world_process_zero=self.is_world_process_zero(),
|
||||||
|
stateful_callbacks=[
|
||||||
|
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.control = TrainerControl()
|
|
||||||
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
|
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
|
||||||
# returned to 0 every time flos need to be logged
|
# returned to 0 every time flos need to be logged
|
||||||
self.current_flos = 0
|
self.current_flos = 0
|
||||||
@@ -1499,6 +1503,8 @@ class Trainer:
|
|||||||
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||||
self.save_model(output_dir, _internal_call=True)
|
self.save_model(output_dir, _internal_call=True)
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
|
# Update the `TrainerControl` state to where we are currently
|
||||||
|
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
|
||||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
@@ -1970,7 +1976,11 @@ class Trainer:
|
|||||||
if not delay_optimizer_creation:
|
if not delay_optimizer_creation:
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
|
|
||||||
self.state = TrainerState()
|
self.state = TrainerState(
|
||||||
|
stateful_callbacks=[
|
||||||
|
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
||||||
|
]
|
||||||
|
)
|
||||||
self.state.is_hyper_param_search = trial is not None
|
self.state.is_hyper_param_search = trial is not None
|
||||||
self.state.train_batch_size = self._train_batch_size
|
self.state.train_batch_size = self._train_batch_size
|
||||||
|
|
||||||
@@ -2079,6 +2089,7 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
||||||
self.compare_trainer_and_checkpoint_args(self.args, self.state)
|
self.compare_trainer_and_checkpoint_args(self.args, self.state)
|
||||||
|
self._load_callback_state()
|
||||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||||
if not args.ignore_data_skip:
|
if not args.ignore_data_skip:
|
||||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||||
@@ -2786,6 +2797,8 @@ class Trainer:
|
|||||||
|
|
||||||
# Save the Trainer state
|
# Save the Trainer state
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
|
# Update the `TrainerControl` state to where we are currently
|
||||||
|
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
|
||||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||||
|
|
||||||
if self.args.push_to_hub:
|
if self.args.push_to_hub:
|
||||||
@@ -2970,6 +2983,45 @@ class Trainer:
|
|||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
|
def _load_callback_state(self):
|
||||||
|
"""If callback states exist and were passed in, restore their states if enabled"""
|
||||||
|
if not self.args.restore_callback_states_from_checkpoint:
|
||||||
|
return
|
||||||
|
# Callback states are stored in stateful_callbacks
|
||||||
|
not_found = []
|
||||||
|
new_callbacks = []
|
||||||
|
original_callbacks = self.callback_handler.callbacks + [self.control]
|
||||||
|
for stored_callback, data in self.state.stateful_callbacks.items():
|
||||||
|
if not isinstance(data, list):
|
||||||
|
data = [data]
|
||||||
|
if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
|
||||||
|
# We can load/restore from multiple callbacks of the same type.
|
||||||
|
duplicates = [
|
||||||
|
callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
|
||||||
|
]
|
||||||
|
for callback, callback_data in zip(duplicates, data):
|
||||||
|
args = callback_data.get("args", {})
|
||||||
|
attributes = callback_data.get("attributes", {})
|
||||||
|
new_callback = type(callback)(**args)
|
||||||
|
for attribute, value in attributes.items():
|
||||||
|
setattr(new_callback, attribute, value)
|
||||||
|
if isinstance(callback, TrainerControl):
|
||||||
|
# Specifically for restoring the `control` state
|
||||||
|
self.control = new_callback
|
||||||
|
else:
|
||||||
|
new_callbacks.append(new_callback)
|
||||||
|
# We remove the existing callback and add it to the list of new callbacks
|
||||||
|
self.callback_handler.remove_callback(type(new_callback))
|
||||||
|
logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
|
||||||
|
else:
|
||||||
|
not_found.append(stored_callback)
|
||||||
|
if len(not_found) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
|
||||||
|
)
|
||||||
|
for callback in new_callbacks:
|
||||||
|
self.callback_handler.add_callback(callback)
|
||||||
|
|
||||||
def hyperparameter_search(
|
def hyperparameter_search(
|
||||||
self,
|
self,
|
||||||
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ class TrainerState:
|
|||||||
is_hyper_param_search (`bool`, *optional*, defaults to `False`):
|
is_hyper_param_search (`bool`, *optional*, defaults to `False`):
|
||||||
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
|
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
|
||||||
impact the way data will be logged in TensorBoard.
|
impact the way data will be logged in TensorBoard.
|
||||||
|
stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
|
||||||
|
Callbacks attached to the `Trainer` that should have their states be saved or restored.
|
||||||
|
Relevent callbacks should implement a `state` and `from_state` function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
epoch: Optional[float] = None
|
epoch: Optional[float] = None
|
||||||
@@ -104,10 +107,34 @@ class TrainerState:
|
|||||||
is_hyper_param_search: bool = False
|
is_hyper_param_search: bool = False
|
||||||
trial_name: str = None
|
trial_name: str = None
|
||||||
trial_params: Dict[str, Union[str, float, int, bool]] = None
|
trial_params: Dict[str, Union[str, float, int, bool]] = None
|
||||||
|
stateful_callbacks: List["TrainerCallback"] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.log_history is None:
|
if self.log_history is None:
|
||||||
self.log_history = []
|
self.log_history = []
|
||||||
|
if self.stateful_callbacks is None:
|
||||||
|
self.stateful_callbacks = {}
|
||||||
|
elif isinstance(self.stateful_callbacks, dict):
|
||||||
|
# We are loading the callbacks in from the state file, no need to process them
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Saveable callbacks get stored as dict of kwargs
|
||||||
|
stateful_callbacks = {}
|
||||||
|
for callback in self.stateful_callbacks:
|
||||||
|
if not isinstance(callback, (ExportableState)):
|
||||||
|
raise TypeError(
|
||||||
|
f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
|
||||||
|
)
|
||||||
|
name = callback.__class__.__name__
|
||||||
|
if name in stateful_callbacks:
|
||||||
|
# We can have multiple versions of the same callback
|
||||||
|
# if so, we store them as a list of states to restore
|
||||||
|
if not isinstance(stateful_callbacks[name], list):
|
||||||
|
stateful_callbacks[name] = [stateful_callbacks[name]]
|
||||||
|
stateful_callbacks[name].append(callback.state())
|
||||||
|
else:
|
||||||
|
stateful_callbacks[name] = callback.state()
|
||||||
|
self.stateful_callbacks = stateful_callbacks
|
||||||
|
|
||||||
def save_to_json(self, json_path: str):
|
def save_to_json(self, json_path: str):
|
||||||
"""Save the content of this instance in JSON format inside `json_path`."""
|
"""Save the content of this instance in JSON format inside `json_path`."""
|
||||||
@@ -123,8 +150,52 @@ class TrainerState:
|
|||||||
return cls(**json.loads(text))
|
return cls(**json.loads(text))
|
||||||
|
|
||||||
|
|
||||||
|
class ExportableState:
|
||||||
|
"""
|
||||||
|
A class for objects that include the ability to have its state
|
||||||
|
be saved during `Trainer._save_checkpoint` and loaded back in during
|
||||||
|
`Trainer._load_from_checkpoint`.
|
||||||
|
|
||||||
|
These must implement a `state` function that gets called during the respective
|
||||||
|
Trainer function call. It should only include parameters and attributes needed to
|
||||||
|
recreate the state at a particular time, to avoid utilizing pickle/maintain standard
|
||||||
|
file IO writing.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class EarlyStoppingCallback(TrainerCallback, ExportableState):
|
||||||
|
def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
|
||||||
|
self.early_stopping_patience = early_stopping_patience
|
||||||
|
self.early_stopping_threshold = early_stopping_threshold
|
||||||
|
# early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
|
||||||
|
self.early_stopping_patience_counter = 0
|
||||||
|
|
||||||
|
def state(self) -> dict:
|
||||||
|
return {
|
||||||
|
"args": {
|
||||||
|
"early_stopping_patience": self.early_stopping_patience,
|
||||||
|
"early_stopping_threshold": self.early_stopping_threshold,
|
||||||
|
},
|
||||||
|
"attributes": {
|
||||||
|
"early_stopping_patience_counter": self.early_stopping_patience_counter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
def state(self) -> dict:
|
||||||
|
raise NotImplementedError("You must implement a `state` function to utilize this class.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_state(cls, state):
|
||||||
|
instance = cls(**state["args"])
|
||||||
|
for k, v in state["attributes"].items():
|
||||||
|
setattr(instance, k, v)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainerControl:
|
class TrainerControl(ExportableState):
|
||||||
"""
|
"""
|
||||||
A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
|
A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
|
||||||
switches in the training loop.
|
switches in the training loop.
|
||||||
@@ -172,6 +243,18 @@ class TrainerControl:
|
|||||||
self.should_evaluate = False
|
self.should_evaluate = False
|
||||||
self.should_log = False
|
self.should_log = False
|
||||||
|
|
||||||
|
def state(self) -> dict:
|
||||||
|
return {
|
||||||
|
"args": {
|
||||||
|
"should_training_stop": self.should_training_stop,
|
||||||
|
"should_epoch_stop": self.should_epoch_stop,
|
||||||
|
"should_save": self.should_save,
|
||||||
|
"should_evaluate": self.should_evaluate,
|
||||||
|
"should_log": self.should_log,
|
||||||
|
},
|
||||||
|
"attributes": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TrainerCallback:
|
class TrainerCallback:
|
||||||
# no-format
|
# no-format
|
||||||
@@ -546,7 +629,7 @@ class PrinterCallback(TrainerCallback):
|
|||||||
print(logs)
|
print(logs)
|
||||||
|
|
||||||
|
|
||||||
class EarlyStoppingCallback(TrainerCallback):
|
class EarlyStoppingCallback(TrainerCallback, ExportableState):
|
||||||
"""
|
"""
|
||||||
A [`TrainerCallback`] that handles early stopping.
|
A [`TrainerCallback`] that handles early stopping.
|
||||||
|
|
||||||
@@ -605,3 +688,14 @@ class EarlyStoppingCallback(TrainerCallback):
|
|||||||
self.check_metric_value(args, state, control, metric_value)
|
self.check_metric_value(args, state, control, metric_value)
|
||||||
if self.early_stopping_patience_counter >= self.early_stopping_patience:
|
if self.early_stopping_patience_counter >= self.early_stopping_patience:
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
|
|
||||||
|
def state(self) -> dict:
|
||||||
|
return {
|
||||||
|
"args": {
|
||||||
|
"early_stopping_patience": self.early_stopping_patience,
|
||||||
|
"early_stopping_threshold": self.early_stopping_threshold,
|
||||||
|
},
|
||||||
|
"attributes": {
|
||||||
|
"early_stopping_patience_counter": self.early_stopping_patience_counter,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|||||||
@@ -357,6 +357,9 @@ class TrainingArguments:
|
|||||||
Note that when this is true, you won't be able to resume training from checkpoint.
|
Note that when this is true, you won't be able to resume training from checkpoint.
|
||||||
This enables you to save storage by not storing the optimizer, scheduler & rng state.
|
This enables you to save storage by not storing the optimizer, scheduler & rng state.
|
||||||
You can only load the model using `from_pretrained` with this option set to `True`.
|
You can only load the model using `from_pretrained` with this option set to `True`.
|
||||||
|
restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to restore the callback states from the checkpoint. If `True`, will override
|
||||||
|
callbacks passed to the `Trainer` if they exist in the checkpoint."
|
||||||
use_cpu (`bool`, *optional*, defaults to `False`):
|
use_cpu (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
|
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
|
||||||
seed (`int`, *optional*, defaults to 42):
|
seed (`int`, *optional*, defaults to 42):
|
||||||
@@ -951,6 +954,12 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
restore_callback_states_from_checkpoint: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
|
||||||
|
},
|
||||||
|
)
|
||||||
no_cuda: bool = field(
|
no_cuda: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
|
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
|
||||||
|
|||||||
@@ -12,6 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -19,28 +21,44 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
DefaultFlowCallback,
|
DefaultFlowCallback,
|
||||||
|
EarlyStoppingCallback,
|
||||||
IntervalStrategy,
|
IntervalStrategy,
|
||||||
PrinterCallback,
|
PrinterCallback,
|
||||||
ProgressCallback,
|
ProgressCallback,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
|
TrainerState,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_torch
|
from transformers.testing_utils import require_torch
|
||||||
|
from transformers.trainer_callback import ExportableState
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.trainer import DEFAULT_CALLBACKS
|
from transformers.trainer import DEFAULT_CALLBACKS, TRAINER_STATE_NAME
|
||||||
|
|
||||||
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
|
class MyTestExportableCallback(TrainerCallback, ExportableState):
|
||||||
|
def __init__(self, my_test_state="test"):
|
||||||
|
self.my_test_state = my_test_state
|
||||||
|
|
||||||
|
def state(self):
|
||||||
|
return {
|
||||||
|
"args": {
|
||||||
|
"my_test_state": self.my_test_state,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MyTestTrainerCallback(TrainerCallback):
|
class MyTestTrainerCallback(TrainerCallback):
|
||||||
"A callback that registers the events that goes through."
|
"A callback that registers the events that goes through."
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, my_test_state="test"):
|
||||||
self.events = []
|
self.events = []
|
||||||
|
self.my_test_state = my_test_state
|
||||||
|
|
||||||
def on_init_end(self, args, state, control, **kwargs):
|
def on_init_end(self, args, state, control, **kwargs):
|
||||||
self.events.append("on_init_end")
|
self.events.append("on_init_end")
|
||||||
@@ -243,3 +261,160 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
||||||
)
|
)
|
||||||
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
|
||||||
|
|
||||||
|
def test_stateful_callbacks(self):
|
||||||
|
# Use something with non-defaults
|
||||||
|
cb = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2)
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=[cb],
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Create a new trainer with defaults
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=[EarlyStoppingCallback()],
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
restore_callback_states_from_checkpoint=True,
|
||||||
|
)
|
||||||
|
# Load it back in and verify values
|
||||||
|
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||||
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
cb = [
|
||||||
|
callback for callback in trainer.callback_handler.callbacks if isinstance(callback, EarlyStoppingCallback)
|
||||||
|
][0]
|
||||||
|
assert cb.early_stopping_patience == 5
|
||||||
|
assert cb.early_stopping_threshold == 0.2
|
||||||
|
|
||||||
|
def test_stateful_mixed_callbacks(self):
|
||||||
|
# Use two callbacks, one stateful one not
|
||||||
|
# Use something with non-defaults
|
||||||
|
cbs = [
|
||||||
|
MyTestTrainerCallback(my_test_state="another value"),
|
||||||
|
EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2),
|
||||||
|
]
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=cbs,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Create a new trainer with defaults
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()],
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
restore_callback_states_from_checkpoint=True,
|
||||||
|
)
|
||||||
|
# Load it back in and verify values
|
||||||
|
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||||
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
cbs = [
|
||||||
|
callback
|
||||||
|
for callback in trainer.callback_handler.callbacks
|
||||||
|
if isinstance(callback, (EarlyStoppingCallback, MyTestTrainerCallback))
|
||||||
|
]
|
||||||
|
assert len(cbs) == 2
|
||||||
|
my_test, early_stopping = cbs
|
||||||
|
assert early_stopping.early_stopping_patience == 5
|
||||||
|
assert early_stopping.early_stopping_threshold == 0.2
|
||||||
|
assert my_test.my_test_state == "test"
|
||||||
|
|
||||||
|
def test_stateful_duplicate_callbacks(self):
|
||||||
|
# Use something with non-defaults
|
||||||
|
cbs = [MyTestExportableCallback("first"), MyTestExportableCallback("second")]
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=cbs,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Create a new trainer with defaults
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
restore_callback_states_from_checkpoint=True,
|
||||||
|
)
|
||||||
|
# Load it back in and verify values
|
||||||
|
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||||
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
cbs = [
|
||||||
|
callback
|
||||||
|
for callback in trainer.callback_handler.callbacks
|
||||||
|
if isinstance(callback, MyTestExportableCallback)
|
||||||
|
]
|
||||||
|
assert len(cbs) == 2
|
||||||
|
assert cbs[0].my_test_state == "first"
|
||||||
|
assert cbs[1].my_test_state == "second"
|
||||||
|
|
||||||
|
def test_missing_stateful_callback(self):
|
||||||
|
cb = EarlyStoppingCallback()
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=[cb],
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Create a new trainer with defaults
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
eval_steps=2,
|
||||||
|
max_steps=2,
|
||||||
|
restore_callback_states_from_checkpoint=True,
|
||||||
|
)
|
||||||
|
# Load it back in and verify values
|
||||||
|
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||||
|
# warning should be emitted for not-present callbacks
|
||||||
|
with patch("transformers.trainer.logger.warning") as warn_mock:
|
||||||
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
assert "EarlyStoppingCallback" in warn_mock.call_args[0][0]
|
||||||
|
|
||||||
|
def test_stateful_control(self):
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
max_steps=2,
|
||||||
|
save_strategy="steps",
|
||||||
|
save_steps=2,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
# Load it back in and verify values
|
||||||
|
trainer = self.get_trainer(max_steps=2, restore_callback_states_from_checkpoint=True)
|
||||||
|
checkpoint = os.path.join(self.output_dir, "checkpoint-2")
|
||||||
|
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
|
||||||
|
trainer._load_callback_state()
|
||||||
|
assert trainer.control.should_training_stop
|
||||||
|
|||||||
Reference in New Issue
Block a user