Push to hub when saving checkpoints (#13503)
* Push to hub when saving checkpoints * Add model card * Revert partial model card * Small fix for checkpoint * Add tests * Add documentation * Fix tests * Bump huggingface_hub * Fix test
This commit is contained in:
@@ -119,6 +119,29 @@ TFTrainingArguments
|
||||
:members:
|
||||
|
||||
|
||||
Checkpoints
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
By default, :class:`~transformers.Trainer` will save all checkpoints in the :obj:`output_dir` you set in the
|
||||
:class:`~transformers.TrainingArguments` you are using. Those will go in subfolder named :obj:`checkpoint-xxx` with xxx
|
||||
being the step at which the training was at.
|
||||
|
||||
Resuming training from a checkpoint can be done when calling :meth:`~transformers.Trainer.train` with either:
|
||||
|
||||
- :obj:`resume_from_checkpoint=True` which will resume training from the latest checkpoint
|
||||
- :obj:`resume_from_checkpoint=checkpoint_dir` which will resume training from the specific checkpoint in the directory
|
||||
passed.
|
||||
|
||||
In addition, you can easily save your checkpoints on the Model Hub when using :obj:`push_to_hub=True`. By default, all
|
||||
the models saved in intermediate checkpoints are saved in different commits, but not the optimizer state. You can adapt
|
||||
the :obj:`hub-strategy` value of your :class:`~transformers.TrainingArguments` to either:
|
||||
|
||||
- :obj:`"checkpoint"`: the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to
|
||||
resume training easily with :obj:`trainer.train(resume_from_checkpoint="output_dir/last-checkpoint")`.
|
||||
- :obj:`"all_checkpoints"`: all checkpoints are pushed like they appear in the output folder (so you will get one
|
||||
checkpoint folder per folder in your final repository)
|
||||
|
||||
|
||||
Logging
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -100,7 +100,7 @@ _deps = [
|
||||
"flax>=0.3.4",
|
||||
"fugashi>=1.0",
|
||||
"GitPython<3.1.19",
|
||||
"huggingface-hub>=0.0.12",
|
||||
"huggingface-hub>=0.0.17",
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
"isort>=5.5.4",
|
||||
|
||||
@@ -18,7 +18,7 @@ deps = {
|
||||
"flax": "flax>=0.3.4",
|
||||
"fugashi": "fugashi>=1.0",
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"huggingface-hub": "huggingface-hub>=0.0.12",
|
||||
"huggingface-hub": "huggingface-hub>=0.0.17",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
|
||||
@@ -110,6 +110,8 @@ from .trainer_utils import (
|
||||
EvalLoopOutput,
|
||||
EvalPrediction,
|
||||
HPSearchBackend,
|
||||
HubStrategy,
|
||||
IntervalStrategy,
|
||||
PredictionOutput,
|
||||
ShardedDDPOption,
|
||||
TrainerMemoryTracker,
|
||||
@@ -180,6 +182,14 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Name of the files used for checkpointing
|
||||
TRAINING_ARGS_NAME = "training_args.bin"
|
||||
TRAINER_STATE_NAME = "trainer_state.json"
|
||||
OPTIMIZER_NAME = "optimizer.pt"
|
||||
SCHEDULER_NAME = "scheduler.pt"
|
||||
SCALER_NAME = "scaler.pt"
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
||||
@@ -389,6 +399,12 @@ class Trainer:
|
||||
# Create clone of distant repo and output directory if needed
|
||||
if self.args.push_to_hub:
|
||||
self.init_git_repo()
|
||||
# In case of pull, we need to make sure every process has the latest.
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("init git repo")
|
||||
elif args.local_rank != -1:
|
||||
dist.barrier()
|
||||
|
||||
if self.args.should_save:
|
||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||
|
||||
@@ -901,9 +917,9 @@ class Trainer:
|
||||
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||
self.save_model(output_dir)
|
||||
if self.args.should_save:
|
||||
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
|
||||
def call_model_init(self, trial=None):
|
||||
model_init_argcount = number_of_arguments(self.model_init)
|
||||
@@ -1183,9 +1199,9 @@ class Trainer:
|
||||
|
||||
# Check if continuing training from a checkpoint
|
||||
if resume_from_checkpoint is not None and os.path.isfile(
|
||||
os.path.join(resume_from_checkpoint, "trainer_state.json")
|
||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
||||
):
|
||||
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
|
||||
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||
if not args.ignore_data_skip:
|
||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||
@@ -1520,9 +1536,9 @@ class Trainer:
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
if smp.dp_rank() == 0:
|
||||
@@ -1530,20 +1546,20 @@ class Trainer:
|
||||
opt_state_dict = self.optimizer.state_dict()
|
||||
# Save it and the scheduler on the main process
|
||||
if self.args.should_save:
|
||||
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.use_amp:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
elif self.args.should_save and not self.deepspeed:
|
||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.use_amp:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
|
||||
# Determine the new best metric / best model checkpoint
|
||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||
@@ -1563,7 +1579,7 @@ class Trainer:
|
||||
|
||||
# Save the Trainer state
|
||||
if self.args.should_save:
|
||||
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||
|
||||
# Save RNG state in non-distributed training
|
||||
rng_states = {
|
||||
@@ -1590,6 +1606,9 @@ class Trainer:
|
||||
else:
|
||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
||||
|
||||
if self.args.push_to_hub:
|
||||
self._push_from_checkpoint(output_dir)
|
||||
|
||||
# Maybe delete some older checkpoints.
|
||||
if self.args.should_save:
|
||||
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||
@@ -1603,15 +1622,15 @@ class Trainer:
|
||||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||
return
|
||||
|
||||
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(checkpoint, "scheduler.pt")
|
||||
if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
|
||||
os.path.join(checkpoint, SCHEDULER_NAME)
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
if is_torch_tpu_available():
|
||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||
optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
|
||||
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
|
||||
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||
@@ -1622,13 +1641,13 @@ class Trainer:
|
||||
else:
|
||||
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
|
||||
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.use_amp and os.path.isfile(os.path.join(checkpoint, "scaler.pt")):
|
||||
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, "scaler.pt")))
|
||||
if self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
|
||||
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||
|
||||
def hyperparameter_search(
|
||||
self,
|
||||
@@ -1908,7 +1927,7 @@ class Trainer:
|
||||
|
||||
if xm.is_master_ordinal():
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
@@ -1953,7 +1972,7 @@ class Trainer:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
def store_flos(self):
|
||||
# Storing the number of floating-point operations that went into the model
|
||||
@@ -2476,9 +2495,9 @@ class Trainer:
|
||||
|
||||
def init_git_repo(self):
|
||||
"""
|
||||
Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
|
||||
Initializes a git repo in :obj:`self.args.hub_model_id`.
|
||||
"""
|
||||
if not self.args.should_save:
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
use_auth_token = True if self.args.hub_token is None else self.args.hub_token
|
||||
if self.args.hub_model_id is None:
|
||||
@@ -2486,17 +2505,36 @@ class Trainer:
|
||||
else:
|
||||
repo_name = self.args.hub_model_id
|
||||
|
||||
try:
|
||||
self.repo = Repository(
|
||||
self.args.output_dir,
|
||||
clone_from=repo_name,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
except EnvironmentError:
|
||||
if self.args.overwrite_output_dir:
|
||||
# Try again after wiping output_dir
|
||||
shutil.rmtree(self.args.output_dir)
|
||||
self.repo = Repository(
|
||||
self.args.output_dir,
|
||||
clone_from=repo_name,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self.repo.git_pull()
|
||||
|
||||
# By default, ignore the checkpoint folders
|
||||
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
|
||||
if (
|
||||
not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
|
||||
and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
|
||||
):
|
||||
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
|
||||
writer.writelines(["checkpoint-*/"])
|
||||
|
||||
self.push_in_progress = None
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
language: Optional[str] = None,
|
||||
@@ -2525,18 +2563,61 @@ class Trainer:
|
||||
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
|
||||
f.write(model_card)
|
||||
|
||||
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
|
||||
def _push_from_checkpoint(self, checkpoint_folder):
|
||||
# Only push from one node.
|
||||
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
|
||||
return
|
||||
# If we haven't finished the last push, we don't do this one.
|
||||
if self.push_in_progress is not None and not self.push_in_progress.is_done:
|
||||
return
|
||||
|
||||
output_dir = self.args.output_dir
|
||||
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
|
||||
modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
|
||||
for modeling_file in modeling_files:
|
||||
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
|
||||
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
|
||||
# Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
# Same for the training arguments
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
try:
|
||||
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
|
||||
# Temporarily move the checkpoint just saved for the push
|
||||
tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
|
||||
# We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
|
||||
# subfolder.
|
||||
if os.path.isdir(tmp_checkpoint):
|
||||
shutil.rmtree(tmp_checkpoint)
|
||||
shutil.move(checkpoint_folder, tmp_checkpoint)
|
||||
|
||||
if self.args.save_strategy == IntervalStrategy.STEPS:
|
||||
commit_message = f"Training in progress, step {self.state.global_step}"
|
||||
else:
|
||||
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
|
||||
_, self.push_in_progress = self.repo.push_to_hub(commit_message=commit_message, blocking=False)
|
||||
finally:
|
||||
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
|
||||
# Move back the checkpoint to its place
|
||||
shutil.move(tmp_checkpoint, checkpoint_folder)
|
||||
|
||||
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
|
||||
"""
|
||||
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
|
||||
|
||||
Parameters:
|
||||
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
|
||||
commit_message (:obj:`str`, `optional`, defaults to :obj:`"End of training"`):
|
||||
Message to commit while pushing.
|
||||
blocking (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether the function should return only when the :obj:`git push` has finished.
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.
|
||||
|
||||
Returns:
|
||||
The url of the commit of your model in the given repository.
|
||||
The url of the commit of your model in the given repository if :obj:`blocking=False`, a tuple with the url
|
||||
of the commit and an object to track the progress of the commit if :obj:`blocking=True`
|
||||
"""
|
||||
|
||||
if self.args.should_save:
|
||||
@@ -2553,7 +2634,7 @@ class Trainer:
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
return self.repo.push_to_hub(commit_message=commit_message)
|
||||
return self.repo.push_to_hub(commit_message=commit_message, blocking=blocking)
|
||||
|
||||
#
|
||||
# Deprecated code
|
||||
|
||||
@@ -125,6 +125,13 @@ class EvaluationStrategy(ExplicitEnum):
|
||||
EPOCH = "epoch"
|
||||
|
||||
|
||||
class HubStrategy(ExplicitEnum):
|
||||
END = "end"
|
||||
EVERY_SAVE = "every_save"
|
||||
CHECKPOINT = "checkpoint"
|
||||
ALL_CHECKPOINTS = "all_checkpoints"
|
||||
|
||||
|
||||
class BestRun(NamedTuple):
|
||||
"""
|
||||
The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
|
||||
|
||||
@@ -32,7 +32,7 @@ from .file_utils import (
|
||||
is_torch_tpu_available,
|
||||
torch_required,
|
||||
)
|
||||
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
||||
from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -343,6 +343,22 @@ class TrainingArguments:
|
||||
|
||||
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
|
||||
:obj:`output_dir`.
|
||||
hub_strategy (:obj:`str` or :class:`~transformers.trainer_utils.HubStrategy`, `optional`, defaults to :obj:`"every_save"`):
|
||||
Defines the scope of what is pushed to the Hub and when. Possible values are:
|
||||
|
||||
- :obj:`"end"`: push the model, its configuration, the tokenizer (if passed along to the
|
||||
:class:`~transformers.Trainer`) and a draft of a model card at the end of training.
|
||||
- :obj:`"every_save"`: push the model, its configuration, the tokenizer (if passed along to the
|
||||
:class:`~transformers.Trainer`) and a draft of a model card each time there is a model save. The pushes
|
||||
are asynchronous to not block training, and in case the save are very frequent, a new push is only
|
||||
attempted if the previous one is finished. A last push is made with the final model at the end of
|
||||
training.
|
||||
- :obj:`"checkpoint"`: like :obj:`"every_save"` but the latest checkpoint is also pushed in a subfolder
|
||||
named last-checkpoint, allowing you to resume training easily with
|
||||
:obj:`trainer.train(resume_from_checkpoint="last-checkpoint")`.
|
||||
- :obj:`"all_checkpoints"`: like :obj:`"checkpoint"` but all checkpoints are pushed like they appear in the
|
||||
output folder (so you will get one checkpoint folder per folder in your final repository)
|
||||
|
||||
hub_token (:obj:`str`, `optional`):
|
||||
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
||||
:obj:`huggingface-cli login`.
|
||||
@@ -618,6 +634,10 @@ class TrainingArguments:
|
||||
hub_model_id: str = field(
|
||||
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||
)
|
||||
hub_strategy: HubStrategy = field(
|
||||
default="every_save",
|
||||
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
|
||||
)
|
||||
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||
# Deprecated arguments
|
||||
push_to_hub_model_id: str = field(
|
||||
@@ -668,6 +688,7 @@ class TrainingArguments:
|
||||
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
|
||||
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
||||
self.save_strategy = IntervalStrategy(self.save_strategy)
|
||||
self.hub_strategy = HubStrategy(self.hub_strategy)
|
||||
|
||||
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
|
||||
|
||||
@@ -18,13 +18,14 @@ import gc
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import HfApi, Repository
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -1284,8 +1285,9 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-trainer")
|
||||
cls._api.delete_repo(token=cls._token, name=model)
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
@@ -1336,6 +1338,55 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
self.assertEqual(model.a.item(), trainer.model.a.item())
|
||||
self.assertEqual(model.b.item(), trainer.model.b.item())
|
||||
|
||||
def get_commit_history(self, repo):
|
||||
commit_logs = subprocess.run(
|
||||
"git log".split(),
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
check=True,
|
||||
encoding="utf-8",
|
||||
cwd=repo,
|
||||
).stdout
|
||||
commits = commit_logs.split("\n\n")[1::2]
|
||||
return [commit.strip() for commit in commits]
|
||||
|
||||
def test_push_to_hub_with_saves_each_epoch(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=os.path.join(tmp_dir, "test-trainer-epoch"),
|
||||
push_to_hub=True,
|
||||
hub_token=self._token,
|
||||
save_strategy="epoch",
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
|
||||
commits = self.get_commit_history(tmp_dir)
|
||||
expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)]
|
||||
expected_commits.append("initial commit")
|
||||
self.assertListEqual(commits, expected_commits)
|
||||
print(commits, len(commits))
|
||||
|
||||
def test_push_to_hub_with_saves_each_n_steps(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=os.path.join(tmp_dir, "test-trainer-step"),
|
||||
push_to_hub=True,
|
||||
hub_token=self._token,
|
||||
save_strategy="steps",
|
||||
save_steps=5,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
|
||||
commits = self.get_commit_history(tmp_dir)
|
||||
expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)]
|
||||
expected_commits.append("initial commit")
|
||||
self.assertListEqual(commits, expected_commits)
|
||||
print(commits, len(commits))
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_optuna
|
||||
|
||||
Reference in New Issue
Block a user