Push to hub when saving checkpoints (#13503)
* Push to hub when saving checkpoints * Add model card * Revert partial model card * Small fix for checkpoint * Add tests * Add documentation * Fix tests * Bump huggingface_hub * Fix test
This commit is contained in:
@@ -119,6 +119,29 @@ TFTrainingArguments
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
Checkpoints
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
By default, :class:`~transformers.Trainer` will save all checkpoints in the :obj:`output_dir` you set in the
|
||||||
|
:class:`~transformers.TrainingArguments` you are using. Those will go in subfolder named :obj:`checkpoint-xxx` with xxx
|
||||||
|
being the step at which the training was at.
|
||||||
|
|
||||||
|
Resuming training from a checkpoint can be done when calling :meth:`~transformers.Trainer.train` with either:
|
||||||
|
|
||||||
|
- :obj:`resume_from_checkpoint=True` which will resume training from the latest checkpoint
|
||||||
|
- :obj:`resume_from_checkpoint=checkpoint_dir` which will resume training from the specific checkpoint in the directory
|
||||||
|
passed.
|
||||||
|
|
||||||
|
In addition, you can easily save your checkpoints on the Model Hub when using :obj:`push_to_hub=True`. By default, all
|
||||||
|
the models saved in intermediate checkpoints are saved in different commits, but not the optimizer state. You can adapt
|
||||||
|
the :obj:`hub-strategy` value of your :class:`~transformers.TrainingArguments` to either:
|
||||||
|
|
||||||
|
- :obj:`"checkpoint"`: the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to
|
||||||
|
resume training easily with :obj:`trainer.train(resume_from_checkpoint="output_dir/last-checkpoint")`.
|
||||||
|
- :obj:`"all_checkpoints"`: all checkpoints are pushed like they appear in the output folder (so you will get one
|
||||||
|
checkpoint folder per folder in your final repository)
|
||||||
|
|
||||||
|
|
||||||
Logging
|
Logging
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -100,7 +100,7 @@ _deps = [
|
|||||||
"flax>=0.3.4",
|
"flax>=0.3.4",
|
||||||
"fugashi>=1.0",
|
"fugashi>=1.0",
|
||||||
"GitPython<3.1.19",
|
"GitPython<3.1.19",
|
||||||
"huggingface-hub>=0.0.12",
|
"huggingface-hub>=0.0.17",
|
||||||
"importlib_metadata",
|
"importlib_metadata",
|
||||||
"ipadic>=1.0.0,<2.0",
|
"ipadic>=1.0.0,<2.0",
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ deps = {
|
|||||||
"flax": "flax>=0.3.4",
|
"flax": "flax>=0.3.4",
|
||||||
"fugashi": "fugashi>=1.0",
|
"fugashi": "fugashi>=1.0",
|
||||||
"GitPython": "GitPython<3.1.19",
|
"GitPython": "GitPython<3.1.19",
|
||||||
"huggingface-hub": "huggingface-hub>=0.0.12",
|
"huggingface-hub": "huggingface-hub>=0.0.17",
|
||||||
"importlib_metadata": "importlib_metadata",
|
"importlib_metadata": "importlib_metadata",
|
||||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
|
|||||||
@@ -110,6 +110,8 @@ from .trainer_utils import (
|
|||||||
EvalLoopOutput,
|
EvalLoopOutput,
|
||||||
EvalPrediction,
|
EvalPrediction,
|
||||||
HPSearchBackend,
|
HPSearchBackend,
|
||||||
|
HubStrategy,
|
||||||
|
IntervalStrategy,
|
||||||
PredictionOutput,
|
PredictionOutput,
|
||||||
ShardedDDPOption,
|
ShardedDDPOption,
|
||||||
TrainerMemoryTracker,
|
TrainerMemoryTracker,
|
||||||
@@ -180,6 +182,14 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Name of the files used for checkpointing
|
||||||
|
TRAINING_ARGS_NAME = "training_args.bin"
|
||||||
|
TRAINER_STATE_NAME = "trainer_state.json"
|
||||||
|
OPTIMIZER_NAME = "optimizer.pt"
|
||||||
|
SCHEDULER_NAME = "scheduler.pt"
|
||||||
|
SCALER_NAME = "scaler.pt"
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""
|
"""
|
||||||
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
||||||
@@ -389,6 +399,12 @@ class Trainer:
|
|||||||
# Create clone of distant repo and output directory if needed
|
# Create clone of distant repo and output directory if needed
|
||||||
if self.args.push_to_hub:
|
if self.args.push_to_hub:
|
||||||
self.init_git_repo()
|
self.init_git_repo()
|
||||||
|
# In case of pull, we need to make sure every process has the latest.
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
xm.rendezvous("init git repo")
|
||||||
|
elif args.local_rank != -1:
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -901,9 +917,9 @@ class Trainer:
|
|||||||
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||||
self.save_model(output_dir)
|
self.save_model(output_dir)
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
|
|
||||||
def call_model_init(self, trial=None):
|
def call_model_init(self, trial=None):
|
||||||
model_init_argcount = number_of_arguments(self.model_init)
|
model_init_argcount = number_of_arguments(self.model_init)
|
||||||
@@ -1183,9 +1199,9 @@ class Trainer:
|
|||||||
|
|
||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if resume_from_checkpoint is not None and os.path.isfile(
|
if resume_from_checkpoint is not None and os.path.isfile(
|
||||||
os.path.join(resume_from_checkpoint, "trainer_state.json")
|
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
||||||
):
|
):
|
||||||
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
|
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
||||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||||
if not args.ignore_data_skip:
|
if not args.ignore_data_skip:
|
||||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||||
@@ -1520,9 +1536,9 @@ class Trainer:
|
|||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.rendezvous("saving_optimizer_states")
|
xm.rendezvous("saving_optimizer_states")
|
||||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
if smp.dp_rank() == 0:
|
if smp.dp_rank() == 0:
|
||||||
@@ -1530,20 +1546,20 @@ class Trainer:
|
|||||||
opt_state_dict = self.optimizer.state_dict()
|
opt_state_dict = self.optimizer.state_dict()
|
||||||
# Save it and the scheduler on the main process
|
# Save it and the scheduler on the main process
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
|
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
if self.use_amp:
|
if self.use_amp:
|
||||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
|
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||||
elif self.args.should_save and not self.deepspeed:
|
elif self.args.should_save and not self.deepspeed:
|
||||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
if self.use_amp:
|
if self.use_amp:
|
||||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
|
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||||
|
|
||||||
# Determine the new best metric / best model checkpoint
|
# Determine the new best metric / best model checkpoint
|
||||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||||
@@ -1563,7 +1579,7 @@ class Trainer:
|
|||||||
|
|
||||||
# Save the Trainer state
|
# Save the Trainer state
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||||
|
|
||||||
# Save RNG state in non-distributed training
|
# Save RNG state in non-distributed training
|
||||||
rng_states = {
|
rng_states = {
|
||||||
@@ -1590,6 +1606,9 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
||||||
|
|
||||||
|
if self.args.push_to_hub:
|
||||||
|
self._push_from_checkpoint(output_dir)
|
||||||
|
|
||||||
# Maybe delete some older checkpoints.
|
# Maybe delete some older checkpoints.
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||||
@@ -1603,15 +1622,15 @@ class Trainer:
|
|||||||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||||
return
|
return
|
||||||
|
|
||||||
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
|
if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
|
||||||
os.path.join(checkpoint, "scheduler.pt")
|
os.path.join(checkpoint, SCHEDULER_NAME)
|
||||||
):
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||||
optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
|
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
|
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||||
@@ -1622,13 +1641,13 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
|
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
|
||||||
self.optimizer.load_state_dict(
|
self.optimizer.load_state_dict(
|
||||||
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
|
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||||
)
|
)
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
if self.use_amp and os.path.isfile(os.path.join(checkpoint, "scaler.pt")):
|
if self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
|
||||||
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, "scaler.pt")))
|
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||||
|
|
||||||
def hyperparameter_search(
|
def hyperparameter_search(
|
||||||
self,
|
self,
|
||||||
@@ -1908,7 +1927,7 @@ class Trainer:
|
|||||||
|
|
||||||
if xm.is_master_ordinal():
|
if xm.is_master_ordinal():
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|
||||||
# Save a trained model and configuration using `save_pretrained()`.
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
@@ -1953,7 +1972,7 @@ class Trainer:
|
|||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|
||||||
def store_flos(self):
|
def store_flos(self):
|
||||||
# Storing the number of floating-point operations that went into the model
|
# Storing the number of floating-point operations that went into the model
|
||||||
@@ -2476,9 +2495,9 @@ class Trainer:
|
|||||||
|
|
||||||
def init_git_repo(self):
|
def init_git_repo(self):
|
||||||
"""
|
"""
|
||||||
Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
|
Initializes a git repo in :obj:`self.args.hub_model_id`.
|
||||||
"""
|
"""
|
||||||
if not self.args.should_save:
|
if not self.is_world_process_zero():
|
||||||
return
|
return
|
||||||
use_auth_token = True if self.args.hub_token is None else self.args.hub_token
|
use_auth_token = True if self.args.hub_token is None else self.args.hub_token
|
||||||
if self.args.hub_model_id is None:
|
if self.args.hub_model_id is None:
|
||||||
@@ -2486,17 +2505,36 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
repo_name = self.args.hub_model_id
|
repo_name = self.args.hub_model_id
|
||||||
|
|
||||||
|
try:
|
||||||
self.repo = Repository(
|
self.repo = Repository(
|
||||||
self.args.output_dir,
|
self.args.output_dir,
|
||||||
clone_from=repo_name,
|
clone_from=repo_name,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
|
except EnvironmentError:
|
||||||
|
if self.args.overwrite_output_dir:
|
||||||
|
# Try again after wiping output_dir
|
||||||
|
shutil.rmtree(self.args.output_dir)
|
||||||
|
self.repo = Repository(
|
||||||
|
self.args.output_dir,
|
||||||
|
clone_from=repo_name,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
self.repo.git_pull()
|
||||||
|
|
||||||
# By default, ignore the checkpoint folders
|
# By default, ignore the checkpoint folders
|
||||||
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
|
if (
|
||||||
|
not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
|
||||||
|
and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
|
||||||
|
):
|
||||||
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
|
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
|
||||||
writer.writelines(["checkpoint-*/"])
|
writer.writelines(["checkpoint-*/"])
|
||||||
|
|
||||||
|
self.push_in_progress = None
|
||||||
|
|
||||||
def create_model_card(
|
def create_model_card(
|
||||||
self,
|
self,
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
@@ -2525,18 +2563,61 @@ class Trainer:
|
|||||||
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
|
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
|
||||||
f.write(model_card)
|
f.write(model_card)
|
||||||
|
|
||||||
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
|
def _push_from_checkpoint(self, checkpoint_folder):
|
||||||
|
# Only push from one node.
|
||||||
|
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
|
||||||
|
return
|
||||||
|
# If we haven't finished the last push, we don't do this one.
|
||||||
|
if self.push_in_progress is not None and not self.push_in_progress.is_done:
|
||||||
|
return
|
||||||
|
|
||||||
|
output_dir = self.args.output_dir
|
||||||
|
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
|
||||||
|
modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
|
||||||
|
for modeling_file in modeling_files:
|
||||||
|
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
|
||||||
|
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
|
||||||
|
# Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
|
||||||
|
if self.tokenizer is not None:
|
||||||
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
# Same for the training arguments
|
||||||
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
|
||||||
|
# Temporarily move the checkpoint just saved for the push
|
||||||
|
tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
|
||||||
|
# We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
|
||||||
|
# subfolder.
|
||||||
|
if os.path.isdir(tmp_checkpoint):
|
||||||
|
shutil.rmtree(tmp_checkpoint)
|
||||||
|
shutil.move(checkpoint_folder, tmp_checkpoint)
|
||||||
|
|
||||||
|
if self.args.save_strategy == IntervalStrategy.STEPS:
|
||||||
|
commit_message = f"Training in progress, step {self.state.global_step}"
|
||||||
|
else:
|
||||||
|
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
|
||||||
|
_, self.push_in_progress = self.repo.push_to_hub(commit_message=commit_message, blocking=False)
|
||||||
|
finally:
|
||||||
|
if self.args.hub_strategy == HubStrategy.CHECKPOINT:
|
||||||
|
# Move back the checkpoint to its place
|
||||||
|
shutil.move(tmp_checkpoint, checkpoint_folder)
|
||||||
|
|
||||||
|
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
|
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
|
commit_message (:obj:`str`, `optional`, defaults to :obj:`"End of training"`):
|
||||||
Message to commit while pushing.
|
Message to commit while pushing.
|
||||||
|
blocking (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether the function should return only when the :obj:`git push` has finished.
|
||||||
kwargs:
|
kwargs:
|
||||||
Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.
|
Additional keyword arguments passed along to :meth:`~transformers.Trainer.create_model_card`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The url of the commit of your model in the given repository.
|
The url of the commit of your model in the given repository if :obj:`blocking=False`, a tuple with the url
|
||||||
|
of the commit and an object to track the progress of the commit if :obj:`blocking=True`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
@@ -2553,7 +2634,7 @@ class Trainer:
|
|||||||
if not self.is_world_process_zero():
|
if not self.is_world_process_zero():
|
||||||
return
|
return
|
||||||
|
|
||||||
return self.repo.push_to_hub(commit_message=commit_message)
|
return self.repo.push_to_hub(commit_message=commit_message, blocking=blocking)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Deprecated code
|
# Deprecated code
|
||||||
|
|||||||
@@ -125,6 +125,13 @@ class EvaluationStrategy(ExplicitEnum):
|
|||||||
EPOCH = "epoch"
|
EPOCH = "epoch"
|
||||||
|
|
||||||
|
|
||||||
|
class HubStrategy(ExplicitEnum):
|
||||||
|
END = "end"
|
||||||
|
EVERY_SAVE = "every_save"
|
||||||
|
CHECKPOINT = "checkpoint"
|
||||||
|
ALL_CHECKPOINTS = "all_checkpoints"
|
||||||
|
|
||||||
|
|
||||||
class BestRun(NamedTuple):
|
class BestRun(NamedTuple):
|
||||||
"""
|
"""
|
||||||
The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
|
The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from .file_utils import (
|
|||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
torch_required,
|
torch_required,
|
||||||
)
|
)
|
||||||
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -343,6 +343,22 @@ class TrainingArguments:
|
|||||||
|
|
||||||
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
|
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
|
||||||
:obj:`output_dir`.
|
:obj:`output_dir`.
|
||||||
|
hub_strategy (:obj:`str` or :class:`~transformers.trainer_utils.HubStrategy`, `optional`, defaults to :obj:`"every_save"`):
|
||||||
|
Defines the scope of what is pushed to the Hub and when. Possible values are:
|
||||||
|
|
||||||
|
- :obj:`"end"`: push the model, its configuration, the tokenizer (if passed along to the
|
||||||
|
:class:`~transformers.Trainer`) and a draft of a model card at the end of training.
|
||||||
|
- :obj:`"every_save"`: push the model, its configuration, the tokenizer (if passed along to the
|
||||||
|
:class:`~transformers.Trainer`) and a draft of a model card each time there is a model save. The pushes
|
||||||
|
are asynchronous to not block training, and in case the save are very frequent, a new push is only
|
||||||
|
attempted if the previous one is finished. A last push is made with the final model at the end of
|
||||||
|
training.
|
||||||
|
- :obj:`"checkpoint"`: like :obj:`"every_save"` but the latest checkpoint is also pushed in a subfolder
|
||||||
|
named last-checkpoint, allowing you to resume training easily with
|
||||||
|
:obj:`trainer.train(resume_from_checkpoint="last-checkpoint")`.
|
||||||
|
- :obj:`"all_checkpoints"`: like :obj:`"checkpoint"` but all checkpoints are pushed like they appear in the
|
||||||
|
output folder (so you will get one checkpoint folder per folder in your final repository)
|
||||||
|
|
||||||
hub_token (:obj:`str`, `optional`):
|
hub_token (:obj:`str`, `optional`):
|
||||||
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
||||||
:obj:`huggingface-cli login`.
|
:obj:`huggingface-cli login`.
|
||||||
@@ -618,6 +634,10 @@ class TrainingArguments:
|
|||||||
hub_model_id: str = field(
|
hub_model_id: str = field(
|
||||||
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
)
|
)
|
||||||
|
hub_strategy: HubStrategy = field(
|
||||||
|
default="every_save",
|
||||||
|
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
|
||||||
|
)
|
||||||
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
# Deprecated arguments
|
# Deprecated arguments
|
||||||
push_to_hub_model_id: str = field(
|
push_to_hub_model_id: str = field(
|
||||||
@@ -668,6 +688,7 @@ class TrainingArguments:
|
|||||||
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
|
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
|
||||||
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
||||||
self.save_strategy = IntervalStrategy(self.save_strategy)
|
self.save_strategy = IntervalStrategy(self.save_strategy)
|
||||||
|
self.hub_strategy = HubStrategy(self.hub_strategy)
|
||||||
|
|
||||||
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||||
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
|
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
|
||||||
|
|||||||
@@ -18,13 +18,14 @@ import gc
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi, Repository
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -1284,8 +1285,9 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
|
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
|
||||||
try:
|
try:
|
||||||
cls._api.delete_repo(token=cls._token, name="test-trainer")
|
cls._api.delete_repo(token=cls._token, name=model)
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -1336,6 +1338,55 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
self.assertEqual(model.a.item(), trainer.model.a.item())
|
self.assertEqual(model.a.item(), trainer.model.a.item())
|
||||||
self.assertEqual(model.b.item(), trainer.model.b.item())
|
self.assertEqual(model.b.item(), trainer.model.b.item())
|
||||||
|
|
||||||
|
def get_commit_history(self, repo):
|
||||||
|
commit_logs = subprocess.run(
|
||||||
|
"git log".split(),
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
check=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
cwd=repo,
|
||||||
|
).stdout
|
||||||
|
commits = commit_logs.split("\n\n")[1::2]
|
||||||
|
return [commit.strip() for commit in commits]
|
||||||
|
|
||||||
|
def test_push_to_hub_with_saves_each_epoch(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=os.path.join(tmp_dir, "test-trainer-epoch"),
|
||||||
|
push_to_hub=True,
|
||||||
|
hub_token=self._token,
|
||||||
|
save_strategy="epoch",
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
|
||||||
|
commits = self.get_commit_history(tmp_dir)
|
||||||
|
expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)]
|
||||||
|
expected_commits.append("initial commit")
|
||||||
|
self.assertListEqual(commits, expected_commits)
|
||||||
|
print(commits, len(commits))
|
||||||
|
|
||||||
|
def test_push_to_hub_with_saves_each_n_steps(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=os.path.join(tmp_dir, "test-trainer-step"),
|
||||||
|
push_to_hub=True,
|
||||||
|
hub_token=self._token,
|
||||||
|
save_strategy="steps",
|
||||||
|
save_steps=5,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
|
||||||
|
commits = self.get_commit_history(tmp_dir)
|
||||||
|
expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)]
|
||||||
|
expected_commits.append("initial commit")
|
||||||
|
self.assertListEqual(commits, expected_commits)
|
||||||
|
print(commits, len(commits))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_optuna
|
@require_optuna
|
||||||
|
|||||||
Reference in New Issue
Block a user