Keras callback to push to hub each epoch, or after N steps (#13773)
* Keras callback to push to hub each epoch, or after N steps * Reworked the callback to use Repository * Use an Enum for save_strategy * Style pass * Correct type for tokenizer * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding print message to the final upload * Adding print message to the final upload * Change how we wait for the last process to finish * is_done is a property, not a method, derp * Docstrings and documentation * Style pass * Style edit * Docstring reformat * Docstring rewrite * Replacing print with internal logger Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -533,6 +533,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
main_classes/callback
|
main_classes/callback
|
||||||
main_classes/configuration
|
main_classes/configuration
|
||||||
main_classes/data_collator
|
main_classes/data_collator
|
||||||
|
main_classes/keras_callbacks
|
||||||
main_classes/logging
|
main_classes/logging
|
||||||
main_classes/model
|
main_classes/model
|
||||||
main_classes/optimizer_schedules
|
main_classes/optimizer_schedules
|
||||||
|
|||||||
22
docs/source/main_classes/keras_callbacks.rst
Normal file
22
docs/source/main_classes/keras_callbacks.rst
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
..
|
||||||
|
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
Keras callbacks
|
||||||
|
=======================================================================================================================
|
||||||
|
|
||||||
|
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
|
||||||
|
tasks:
|
||||||
|
|
||||||
|
PushToHubCallback
|
||||||
|
-----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
.. autoclass:: transformers.keras_callbacks.PushToHubCallback
|
||||||
97
src/transformers/keras_callbacks.py
Normal file
97
src/transformers/keras_callbacks.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from time import sleep
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from tensorflow.keras.callbacks import Callback
|
||||||
|
|
||||||
|
from huggingface_hub import Repository
|
||||||
|
|
||||||
|
from . import IntervalStrategy, PreTrainedTokenizerBase
|
||||||
|
from .file_utils import get_full_repo_name
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger, named after this module; used to report upload progress.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PushToHubCallback(Callback):
    """
    Keras callback that saves the model (and, if supplied, its tokenizer) into a local clone of a Hub repository and
    pushes that clone to the Hub while training is in progress and once more when training ends.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        save_strategy: Union[str, IntervalStrategy] = "epoch",
        save_steps: Optional[int] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
    ):
        """
        Args:
            output_dir (:obj:`str`):
                The output directory where the model predictions and checkpoints will be written and synced with the
                repository on the Hub.
            save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"epoch"`):
                The checkpoint save strategy to adopt during training. Possible values are:

                    * :obj:`"no"`: No save is done during training.
                    * :obj:`"epoch"`: Save is done at the end of each epoch.
                    * :obj:`"steps"`: Save is done every :obj:`save_steps`
            save_steps (:obj:`int`, `optional`):
                The number of steps between saves when using the "steps" save_strategy.
            tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`):
                The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
            hub_model_id (:obj:`str`, `optional`):
                The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
                name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
                of with :obj:`"organization_name/model"`. Will default to :obj:`user_name/output_dir_name` with
                `output_dir_name` being the name of :obj:`output_dir`.
            hub_token (:obj:`str`, `optional`):
                The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
                :obj:`huggingface-cli login`.

        Raises:
            :obj:`ValueError`: If :obj:`save_strategy` is ``"steps"`` and :obj:`save_steps` is not a positive integer.
        """
        super().__init__()
        # Accept the strategy as a plain string for convenience and normalise it to the enum.
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
        if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
            raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
        self.save_steps = save_steps
        output_dir = Path(output_dir)
        if hub_model_id is None:
            # Derive the repository name from the last component of the output directory.
            repo_name = get_full_repo_name(output_dir.absolute().name, token=hub_token)
        else:
            repo_name = hub_model_id
        self.output_dir = output_dir
        self.repo = Repository(str(output_dir), clone_from=repo_name)
        self.tokenizer = tokenizer
        # Handle on the most recent non-blocking push, or None if nothing has been pushed yet.
        self.last_job = None

    def _save_checkpoint(self):
        """Write the model, and the tokenizer if one was supplied, into :obj:`output_dir`."""
        self.model.save_pretrained(self.output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(self.output_dir)

    def _push_in_progress(self, commit_message):
        """Save and start a non-blocking push, unless the previous push is still running."""
        if self.last_job is not None and not self.last_job.is_done:
            return  # The last upload is still running, don't start another
        self._save_checkpoint()
        _, self.last_job = self.repo.push_to_hub(commit_message=commit_message, blocking=False)

    def on_train_batch_end(self, batch, logs=None):
        """Push a checkpoint every :obj:`save_steps` batches when the save strategy is ``"steps"``."""
        # `batch + 1` must be parenthesised: `%` binds tighter than `+`, so without the parentheses the
        # expression is `batch + (1 % save_steps)`, which almost never equals zero.
        # NOTE(review): Keras documents `batch` as the index within the current epoch, so the count
        # restarts every epoch -- confirm this is the intended meaning of "steps".
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
            self._push_in_progress(f"Training in progress steps {batch}")

    def on_epoch_end(self, epoch, logs=None):
        """Push a checkpoint at the end of each epoch when the save strategy is ``"epoch"``."""
        if self.save_strategy == IntervalStrategy.EPOCH:
            self._push_in_progress(f"Training in progress epoch {epoch}")

    def on_train_end(self, logs=None):
        """Wait for any running upload, then save and push the final model with a blocking call."""
        if self.last_job is not None and not self.last_job.is_done:
            logger.info("Waiting for existing upload to finish...")
            # Poll until the previous non-blocking push reports completion.
            while not self.last_job.is_done:
                sleep(1)
        self._save_checkpoint()
        self.repo.push_to_hub(commit_message="End of training", blocking=True)
|
||||||
Reference in New Issue
Block a user