Deprecate model_path in Trainer.train (#9854)
This commit is contained in:
@@ -362,12 +362,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -403,12 +403,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -355,12 +355,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -384,12 +384,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -342,12 +342,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif os.path.isdir(model_args.model_name_or_path):
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -463,12 +463,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif os.path.isdir(model_args.model_name_or_path):
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -502,12 +502,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif os.path.isdir(model_args.model_name_or_path):
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -491,12 +491,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif os.path.isdir(model_args.model_name_or_path):
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -399,12 +399,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif os.path.isdir(model_args.model_name_or_path):
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
metrics = train_result.metrics
|
metrics = train_result.metrics
|
||||||
|
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|||||||
@@ -380,12 +380,12 @@ def main():
|
|||||||
# Training
|
# Training
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif os.path.isdir(model_args.model_name_or_path):
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -125,13 +125,13 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
|
|||||||
import optuna
|
import optuna
|
||||||
|
|
||||||
def _objective(trial, checkpoint_dir=None):
|
def _objective(trial, checkpoint_dir=None):
|
||||||
model_path = None
|
checkpoint = None
|
||||||
if checkpoint_dir:
|
if checkpoint_dir:
|
||||||
for subdir in os.listdir(checkpoint_dir):
|
for subdir in os.listdir(checkpoint_dir):
|
||||||
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
|
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
|
||||||
model_path = os.path.join(checkpoint_dir, subdir)
|
checkpoint = os.path.join(checkpoint_dir, subdir)
|
||||||
trainer.objective = None
|
trainer.objective = None
|
||||||
trainer.train(model_path=model_path, trial=trial)
|
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
|
||||||
# If there hasn't been any evaluation during the training loop.
|
# If there hasn't been any evaluation during the training loop.
|
||||||
if getattr(trainer, "objective", None) is None:
|
if getattr(trainer, "objective", None) is None:
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
@@ -150,13 +150,13 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||||||
import ray
|
import ray
|
||||||
|
|
||||||
def _objective(trial, local_trainer, checkpoint_dir=None):
|
def _objective(trial, local_trainer, checkpoint_dir=None):
|
||||||
model_path = None
|
checkpoint = None
|
||||||
if checkpoint_dir:
|
if checkpoint_dir:
|
||||||
for subdir in os.listdir(checkpoint_dir):
|
for subdir in os.listdir(checkpoint_dir):
|
||||||
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
|
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
|
||||||
model_path = os.path.join(checkpoint_dir, subdir)
|
checkpoint = os.path.join(checkpoint_dir, subdir)
|
||||||
local_trainer.objective = None
|
local_trainer.objective = None
|
||||||
local_trainer.train(model_path=model_path, trial=trial)
|
local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
|
||||||
# If there hasn't been any evaluation during the training loop.
|
# If there hasn't been any evaluation during the training loop.
|
||||||
if getattr(local_trainer, "objective", None) is None:
|
if getattr(local_trainer, "objective", None) is None:
|
||||||
metrics = local_trainer.evaluate()
|
metrics = local_trainer.evaluate()
|
||||||
|
|||||||
@@ -676,17 +676,33 @@ class Trainer:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
def train(
|
||||||
|
self,
|
||||||
|
resume_from_checkpoint: Optional[str] = None,
|
||||||
|
trial: Union["optuna.Trial", Dict[str, Any]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Main training entry point.
|
Main training entry point.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path (:obj:`str`, `optional`):
|
resume_from_checkpoint (:obj:`str`, `optional`):
|
||||||
Local path to the model if the model to train has been instantiated from a local path. If present,
|
Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
|
||||||
training will resume from the optimizer/scheduler states loaded here.
|
present, training will resume from the model/optimizer/scheduler states loaded here.
|
||||||
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
||||||
The trial run or the hyperparameter dictionary for hyperparameter search.
|
The trial run or the hyperparameter dictionary for hyperparameter search.
|
||||||
|
kwargs:
|
||||||
|
Additional keyword arguments used to hide deprecated arguments
|
||||||
"""
|
"""
|
||||||
|
if "model_path" in kwargs:
|
||||||
|
resume_from_checkpoint = kwargs.pop("model_path")
|
||||||
|
warnings.warn(
|
||||||
|
"`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
|
||||||
|
"instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
if len(kwargs) > 0:
|
||||||
|
raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
|
||||||
# This might change the seed so needs to run first.
|
# This might change the seed so needs to run first.
|
||||||
self._hp_search_setup(trial)
|
self._hp_search_setup(trial)
|
||||||
|
|
||||||
@@ -701,13 +717,13 @@ class Trainer:
|
|||||||
self.optimizer, self.lr_scheduler = None, None
|
self.optimizer, self.lr_scheduler = None, None
|
||||||
|
|
||||||
# Load potential model checkpoint
|
# Load potential model checkpoint
|
||||||
if model_path is not None and os.path.isfile(os.path.join(model_path, WEIGHTS_NAME)):
|
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
||||||
logger.info(f"Loading model from {model_path}).")
|
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||||
if isinstance(self.model, PreTrainedModel):
|
if isinstance(self.model, PreTrainedModel):
|
||||||
self.model = self.model.from_pretrained(model_path)
|
self.model = self.model.from_pretrained(resume_from_checkpoint)
|
||||||
model_reloaded = True
|
model_reloaded = True
|
||||||
else:
|
else:
|
||||||
state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME))
|
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||||
@@ -757,7 +773,7 @@ class Trainer:
|
|||||||
self.state.is_hyper_param_search = trial is not None
|
self.state.is_hyper_param_search = trial is not None
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
self._load_optimizer_and_scheduler(model_path)
|
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||||
|
|
||||||
model = self.model_wrapped
|
model = self.model_wrapped
|
||||||
|
|
||||||
@@ -827,8 +843,10 @@ class Trainer:
|
|||||||
steps_trained_in_current_epoch = 0
|
steps_trained_in_current_epoch = 0
|
||||||
|
|
||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
|
if resume_from_checkpoint is not None and os.path.isfile(
|
||||||
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
|
os.path.join(resume_from_checkpoint, "trainer_state.json")
|
||||||
|
):
|
||||||
|
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
|
||||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||||
if not self.args.ignore_data_skip:
|
if not self.args.ignore_data_skip:
|
||||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||||
@@ -1102,20 +1120,20 @@ class Trainer:
|
|||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
self._rotate_checkpoints(use_mtime=True)
|
self._rotate_checkpoints(use_mtime=True)
|
||||||
|
|
||||||
def _load_optimizer_and_scheduler(self, model_path):
|
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||||
"""If optimizer and scheduler states exist, load them."""
|
"""If optimizer and scheduler states exist, load them."""
|
||||||
if model_path is None:
|
if checkpoint is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(
|
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
|
||||||
os.path.join(model_path, "scheduler.pt")
|
os.path.join(checkpoint, "scheduler.pt")
|
||||||
):
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||||
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
|
optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
|
lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||||
@@ -1125,15 +1143,15 @@ class Trainer:
|
|||||||
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
||||||
else:
|
else:
|
||||||
self.optimizer.load_state_dict(
|
self.optimizer.load_state_dict(
|
||||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device)
|
||||||
)
|
)
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
if self.deepspeed:
|
if self.deepspeed:
|
||||||
# Not sure how to check if there is a saved deepspeed checkpoint, but since it just return None if it fails to find a deepspeed checkpoint this is sort of a check-n-load function
|
# Not sure how to check if there is a saved deepspeed checkpoint, but since it just return None if it fails to find a deepspeed checkpoint this is sort of a check-n-load function
|
||||||
self.deepspeed.load_checkpoint(model_path, load_optimizer_states=True, load_lr_scheduler_states=True)
|
self.deepspeed.load_checkpoint(checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True)
|
||||||
|
|
||||||
def hyperparameter_search(
|
def hyperparameter_search(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -341,20 +341,20 @@ def main():
|
|||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
{%- if cookiecutter.can_train_from_scratch == "False" %}
|
{%- if cookiecutter.can_train_from_scratch == "False" %}
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif os.path.isdir(model_args.model_name_or_path):
|
elif os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
{%- elif cookiecutter.can_train_from_scratch == "True" %}
|
{%- elif cookiecutter.can_train_from_scratch == "True" %}
|
||||||
if last_checkpoint is not None:
|
if last_checkpoint is not None:
|
||||||
model_path = last_checkpoint
|
checkpoint = last_checkpoint
|
||||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||||
model_path = model_args.model_name_or_path
|
checkpoint = model_args.model_name_or_path
|
||||||
else:
|
else:
|
||||||
model_path = None
|
checkpoint = None
|
||||||
{% endif %}
|
{% endif %}
|
||||||
train_result = trainer.train(model_path=model_path)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||||
|
|||||||
@@ -581,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
# Reinitialize trainer
|
# Reinitialize trainer
|
||||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
state1 = dataclasses.asdict(trainer.state)
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
self.assertEqual(a, a1)
|
self.assertEqual(a, a1)
|
||||||
@@ -594,7 +594,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
# Reinitialize trainer and load model
|
# Reinitialize trainer and load model
|
||||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
state1 = dataclasses.asdict(trainer.state)
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
self.assertEqual(a, a1)
|
self.assertEqual(a, a1)
|
||||||
@@ -617,7 +617,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
state1 = dataclasses.asdict(trainer.state)
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
self.assertEqual(a, a1)
|
self.assertEqual(a, a1)
|
||||||
@@ -632,7 +632,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
state1 = dataclasses.asdict(trainer.state)
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
self.assertEqual(a, a1)
|
self.assertEqual(a, a1)
|
||||||
@@ -670,7 +670,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
learning_rate=0.1,
|
learning_rate=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.train(model_path=checkpoint)
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
state1 = dataclasses.asdict(trainer.state)
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
self.assertEqual(a, a1)
|
self.assertEqual(a, a1)
|
||||||
|
|||||||
Reference in New Issue
Block a user