From 3a8a8013adc5802cc302e63e403f0e8925fab0ea Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 29 Sep 2021 12:47:35 +0100 Subject: [PATCH] Keras callback to push to hub each epoch, or after N steps (#13773) * Keras callback to push to hub each epoch, or after N steps * Reworked the callback to use Repository * Use an Enum for save_strategy * Style pass * Correct type for tokenizer * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding print message to the final upload * Adding print message to the final upload * Change how we wait for the last process to finish * is_done is a property, not a method, derp * Docstrings and documentation * Style pass * Style edit * Docstring reformat * Docstring rewrite * Replacing print with internal logger Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/index.rst | 1 + docs/source/main_classes/keras_callbacks.rst | 22 +++++ src/transformers/keras_callbacks.py | 97 ++++++++++++++++++++ 3 files changed, 120 insertions(+) create mode 100644 docs/source/main_classes/keras_callbacks.rst create mode 100644 src/transformers/keras_callbacks.py diff --git a/docs/source/index.rst b/docs/source/index.rst index cd28423005..bd0ca334c6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -533,6 +533,7 @@ Flax), PyTorch, and/or 
TensorFlow. main_classes/callback main_classes/configuration main_classes/data_collator + main_classes/keras_callbacks main_classes/logging main_classes/model main_classes/optimizer_schedules diff --git a/docs/source/main_classes/keras_callbacks.rst b/docs/source/main_classes/keras_callbacks.rst new file mode 100644 index 0000000000..476802469b --- /dev/null +++ b/docs/source/main_classes/keras_callbacks.rst @@ -0,0 +1,22 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +Keras callbacks +======================================================================================================================= + +When training a Transformers model with Keras, there are some library-specific callbacks available to automate common +tasks: + +PushToHubCallback +----------------------------------------------------------------------------------------------------------------------- + +.. autoclass:: transformers.keras_callbacks.PushToHubCallback diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py new file mode 100644 index 0000000000..fc4f53f88b --- /dev/null +++ b/src/transformers/keras_callbacks.py @@ -0,0 +1,97 @@ +import logging +from pathlib import Path +from time import sleep +from typing import Optional, Union + +from tensorflow.keras.callbacks import Callback + +from huggingface_hub import Repository + +from . 
import IntervalStrategy, PreTrainedTokenizerBase  # completes the "from . " statement begun on the previous line
from .file_utils import get_full_repo_name


logger = logging.getLogger(__name__)


class PushToHubCallback(Callback):
    """
    Keras callback that saves the model (and, if supplied, the tokenizer) into :obj:`output_dir` and pushes that
    directory to a repository on the Hub, either every epoch, every :obj:`save_steps` batches, or only once at the
    end of training.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        save_strategy: Union[str, IntervalStrategy] = "epoch",
        save_steps: Optional[int] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
    ):
        """
        output_dir (:obj:`str`):
            The output directory where the model predictions and checkpoints will be written and synced with the
            repository on the Hub.
        save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"epoch"`):
            The checkpoint save strategy to adopt during training. Possible values are:

                * :obj:`"no"`: No save is done during training.
                * :obj:`"epoch"`: Save is done at the end of each epoch.
                * :obj:`"steps"`: Save is done every :obj:`save_steps` steps.
        save_steps (:obj:`int`, `optional`):
            The number of steps between saves when using the "steps" save_strategy.
        tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
        hub_model_id (:obj:`str`, `optional`):
            The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
            name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
            of with :obj:`"organization_name/model"`. Will default to :obj:`user_name/output_dir_name` with
            `output_dir_name` being the name of :obj:`output_dir`.
        hub_token (:obj:`str`, `optional`):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            :obj:`huggingface-cli login`.

        Raises:
            :obj:`ValueError`: If :obj:`save_strategy` is "steps" and :obj:`save_steps` is not a positive integer
            (or if :obj:`save_strategy` is a string that is not a valid :class:`IntervalStrategy` value).
        """
        super().__init__()
        # Accept plain strings case-insensitively; the enum constructor rejects unknown values.
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
        if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
            raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
        self.save_steps = save_steps

        output_dir = Path(output_dir)
        if hub_model_id is None:
            # Derive "<namespace>/<directory name>" from the final component of the absolute output path.
            repo_name = get_full_repo_name(output_dir.absolute().name, token=hub_token)
        else:
            repo_name = hub_model_id
        self.output_dir = output_dir
        # NOTE(review): hub_token is only used above to resolve the repo name; it is not forwarded to
        # Repository, so pushes rely on whatever credentials Repository picks up by default - confirm
        # whether the token should be passed through here.
        self.repo = Repository(str(output_dir), clone_from=repo_name)
        self.tokenizer = tokenizer
        # Handle of the most recent non-blocking push (None until the first push has been started).
        self.last_job = None

    def _upload_in_progress(self) -> bool:
        """Return True if a previously started non-blocking push has not finished yet."""
        return self.last_job is not None and not self.last_job.is_done

    def _save_model_and_tokenizer(self):
        """Write the current model, and the tokenizer when one was supplied, into ``self.output_dir``."""
        self.model.save_pretrained(self.output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(self.output_dir)

    def on_train_batch_end(self, batch, logs=None):
        """
        With the "steps" strategy, save and start a non-blocking push every ``save_steps`` batches.

        ``batch`` is the batch index handed to the callback by Keras; ``batch + 1`` is used so that the first push
        happens after ``save_steps`` batches rather than after the very first one.
        """
        if self.save_strategy != IntervalStrategy.STEPS:
            return
        # The parentheses matter: `batch + 1 % n` parses as `batch + (1 % n)`, which is never 0 for n > 1,
        # so without them the "steps" strategy would never trigger an upload.
        if (batch + 1) % self.save_steps != 0:
            return
        if self._upload_in_progress():
            return  # The last upload is still running, don't start another
        self._save_model_and_tokenizer()
        _, self.last_job = self.repo.push_to_hub(
            commit_message=f"Training in progress steps {batch}", blocking=False
        )

    def on_epoch_end(self, epoch, logs=None):
        """With the "epoch" strategy, save and start a non-blocking push at the end of every epoch."""
        if self.save_strategy != IntervalStrategy.EPOCH:
            return
        if self._upload_in_progress():
            return  # The last upload is still running, don't start another
        self._save_model_and_tokenizer()
        _, self.last_job = self.repo.push_to_hub(
            commit_message=f"Training in progress epoch {epoch}", blocking=False
        )

    def on_train_end(self, logs=None):
        """Wait for any pending upload, then save the final state and push it with a blocking call."""
        if self._upload_in_progress():
            logger.info("Waiting for existing upload to finish...")
            # Poll once per second until the pending push reports completion.
            while not self.last_job.is_done:
                sleep(1)
        self._save_model_and_tokenizer()
        self.repo.push_to_hub(commit_message="End of training", blocking=True)