Add proper documentation for Keras callbacks (#15374)
* Add proper documentation for Keras callbacks * Add dummies
This commit is contained in:
@@ -15,6 +15,10 @@ specific language governing permissions and limitations under the License.
|
|||||||
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
|
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
|
||||||
tasks:
|
tasks:
|
||||||
|
|
||||||
|
## KerasMetricCallback
|
||||||
|
|
||||||
|
[[autodoc]] KerasMetricCallback
|
||||||
|
|
||||||
## PushToHubCallback
|
## PushToHubCallback
|
||||||
|
|
||||||
[[autodoc]] keras_callbacks.PushToHubCallback
|
[[autodoc]] PushToHubCallback
|
||||||
|
|||||||
@@ -1550,7 +1550,7 @@ if is_tf_available():
|
|||||||
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
|
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
|
||||||
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
|
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
|
||||||
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
|
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
|
||||||
_import_structure["keras_callbacks"] = ["PushToHubCallback"]
|
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
|
||||||
_import_structure["modeling_tf_outputs"] = []
|
_import_structure["modeling_tf_outputs"] = []
|
||||||
_import_structure["modeling_tf_utils"] = [
|
_import_structure["modeling_tf_utils"] = [
|
||||||
"TFPreTrainedModel",
|
"TFPreTrainedModel",
|
||||||
@@ -3486,7 +3486,7 @@ if TYPE_CHECKING:
|
|||||||
# Benchmarks
|
# Benchmarks
|
||||||
from .benchmark.benchmark_tf import TensorFlowBenchmark
|
from .benchmark.benchmark_tf import TensorFlowBenchmark
|
||||||
from .generation_tf_utils import tf_top_k_top_p_filtering
|
from .generation_tf_utils import tf_top_k_top_p_filtering
|
||||||
from .keras_callbacks import PushToHubCallback
|
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
|
||||||
from .modeling_tf_layoutlm import (
|
from .modeling_tf_layoutlm import (
|
||||||
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFLayoutLMForMaskedLM,
|
TFLayoutLMForMaskedLM,
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class KerasMetricCallback(Callback):
|
|||||||
supplied.
|
supplied.
|
||||||
batch_size (`int`, *optional*):
|
batch_size (`int`, *optional*):
|
||||||
Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
|
Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
|
||||||
predict_with_generate: (`bool`, *optional*, defaults to `False`):
|
predict_with_generate (`bool`, *optional*, defaults to `False`):
|
||||||
Whether we should use `model.generate()` to get outputs for the model.
|
Whether we should use `model.generate()` to get outputs for the model.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -240,6 +240,38 @@ class KerasMetricCallback(Callback):
|
|||||||
|
|
||||||
|
|
||||||
class PushToHubCallback(Callback):
|
class PushToHubCallback(Callback):
|
||||||
|
"""
|
||||||
|
Callback that will save and push the model to the Hub regularly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir (`str`):
|
||||||
|
The output directory where the model predictions and checkpoints will be written and synced with the
|
||||||
|
repository on the Hub.
|
||||||
|
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
|
||||||
|
The checkpoint save strategy to adopt during training. Possible values are:
|
||||||
|
|
||||||
|
- `"no"`: No save is done during training.
|
||||||
|
- `"epoch"`: Save is done at the end of each epoch.
|
||||||
|
- `"steps"`: Save is done every `save_steps` steps.
|
||||||
|
save_steps (`int`, *optional*):
|
||||||
|
The number of steps between saves when using the "steps" `save_strategy`.
|
||||||
|
tokenizer (`PreTrainedTokenizerBase`, *optional*):
|
||||||
|
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
|
||||||
|
hub_model_id (`str`, *optional*):
|
||||||
|
The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
|
||||||
|
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
|
||||||
|
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
|
||||||
|
`"organization_name/model"`.
|
||||||
|
|
||||||
|
Will default to the name of `output_dir`.
|
||||||
|
hub_token (`str`, *optional*):
|
||||||
|
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
||||||
|
`huggingface-cli login`.
|
||||||
|
checkpoint (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
|
||||||
|
resumed. Only usable when `save_strategy` is `"epoch"`.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
output_dir: Union[str, Path],
|
output_dir: Union[str, Path],
|
||||||
@@ -251,34 +283,6 @@ class PushToHubCallback(Callback):
|
|||||||
checkpoint: bool = False,
|
checkpoint: bool = False,
|
||||||
**model_card_args
|
**model_card_args
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
output_dir (`str`):
|
|
||||||
The output directory where the model predictions and checkpoints will be written and synced with the
|
|
||||||
repository on the Hub.
|
|
||||||
save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
|
|
||||||
The checkpoint save strategy to adopt during training. Possible values are:
|
|
||||||
|
|
||||||
- `"no"`: No save is done during training.
|
|
||||||
- `"epoch"`: Save is done at the end of each epoch.
|
|
||||||
- `"steps"`: Save is done every `save_steps`
|
|
||||||
save_steps (`int`, *optional*):
|
|
||||||
The number of steps between saves when using the "steps" save_strategy.
|
|
||||||
tokenizer (`PreTrainedTokenizerBase`, *optional*):
|
|
||||||
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
|
|
||||||
hub_model_id (`str`, *optional*):
|
|
||||||
The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
|
|
||||||
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
|
|
||||||
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
|
|
||||||
`"organization_name/model"`.
|
|
||||||
|
|
||||||
Will default to the name of `output_dir`.
|
|
||||||
hub_token (`str`, *optional*):
|
|
||||||
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
|
||||||
`huggingface-cli login`.
|
|
||||||
checkpoint (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
|
|
||||||
resumed. Only usable when *save_strategy* is *epoch*.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if checkpoint and save_strategy != "epoch":
|
if checkpoint and save_strategy != "epoch":
|
||||||
raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
|
raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
|
||||||
|
|||||||
@@ -21,6 +21,13 @@ def tf_top_k_top_p_filtering(*args, **kwargs):
|
|||||||
requires_backends(tf_top_k_top_p_filtering, ["tf"])
|
requires_backends(tf_top_k_top_p_filtering, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class KerasMetricCallback(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class PushToHubCallback(metaclass=DummyObject):
|
class PushToHubCallback(metaclass=DummyObject):
|
||||||
_backends = ["tf"]
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user