Deprecate model_path in Trainer.train (#9854)
This commit is contained in:
@@ -362,12 +362,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -403,12 +403,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -355,12 +355,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -384,12 +384,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -342,12 +342,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -463,12 +463,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -502,12 +502,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -491,12 +491,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -399,12 +399,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
@@ -380,12 +380,12 @@ def main():
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
checkpoint = None
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -125,13 +125,13 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
|
||||
import optuna
|
||||
|
||||
def _objective(trial, checkpoint_dir=None):
|
||||
model_path = None
|
||||
checkpoint = None
|
||||
if checkpoint_dir:
|
||||
for subdir in os.listdir(checkpoint_dir):
|
||||
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
|
||||
model_path = os.path.join(checkpoint_dir, subdir)
|
||||
checkpoint = os.path.join(checkpoint_dir, subdir)
|
||||
trainer.objective = None
|
||||
trainer.train(model_path=model_path, trial=trial)
|
||||
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
|
||||
# If there hasn't been any evaluation during the training loop.
|
||||
if getattr(trainer, "objective", None) is None:
|
||||
metrics = trainer.evaluate()
|
||||
@@ -150,13 +150,13 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
import ray
|
||||
|
||||
def _objective(trial, local_trainer, checkpoint_dir=None):
|
||||
model_path = None
|
||||
checkpoint = None
|
||||
if checkpoint_dir:
|
||||
for subdir in os.listdir(checkpoint_dir):
|
||||
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
|
||||
model_path = os.path.join(checkpoint_dir, subdir)
|
||||
checkpoint = os.path.join(checkpoint_dir, subdir)
|
||||
local_trainer.objective = None
|
||||
local_trainer.train(model_path=model_path, trial=trial)
|
||||
local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
|
||||
# If there hasn't been any evaluation during the training loop.
|
||||
if getattr(local_trainer, "objective", None) is None:
|
||||
metrics = local_trainer.evaluate()
|
||||
|
||||
@@ -676,17 +676,33 @@ class Trainer:
|
||||
|
||||
return model
|
||||
|
||||
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
||||
def train(
|
||||
self,
|
||||
resume_from_checkpoint: Optional[str] = None,
|
||||
trial: Union["optuna.Trial", Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Main training entry point.
|
||||
|
||||
Args:
|
||||
model_path (:obj:`str`, `optional`):
|
||||
Local path to the model if the model to train has been instantiated from a local path. If present,
|
||||
training will resume from the optimizer/scheduler states loaded here.
|
||||
resume_from_checkpoint (:obj:`str`, `optional`):
|
||||
Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
|
||||
present, training will resume from the model/optimizer/scheduler states loaded here.
|
||||
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
||||
The trial run or the hyperparameter dictionary for hyperparameter search.
|
||||
kwargs:
|
||||
Additional keyword arguments used to hide deprecated arguments
|
||||
"""
|
||||
if "model_path" in kwargs:
|
||||
resume_from_checkpoint = kwargs.pop("model_path")
|
||||
warnings.warn(
|
||||
"`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
|
||||
"instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if len(kwargs) > 0:
|
||||
raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
|
||||
# This might change the seed so needs to run first.
|
||||
self._hp_search_setup(trial)
|
||||
|
||||
@@ -701,13 +717,13 @@ class Trainer:
|
||||
self.optimizer, self.lr_scheduler = None, None
|
||||
|
||||
# Load potential model checkpoint
|
||||
if model_path is not None and os.path.isfile(os.path.join(model_path, WEIGHTS_NAME)):
|
||||
logger.info(f"Loading model from {model_path}).")
|
||||
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
||||
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
self.model = self.model.from_pretrained(model_path)
|
||||
self.model = self.model.from_pretrained(resume_from_checkpoint)
|
||||
model_reloaded = True
|
||||
else:
|
||||
state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME))
|
||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||
@@ -757,7 +773,7 @@ class Trainer:
|
||||
self.state.is_hyper_param_search = trial is not None
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
self._load_optimizer_and_scheduler(model_path)
|
||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||
|
||||
model = self.model_wrapped
|
||||
|
||||
@@ -827,8 +843,10 @@ class Trainer:
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
# Check if continuing training from a checkpoint
|
||||
if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
|
||||
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
|
||||
if resume_from_checkpoint is not None and os.path.isfile(
|
||||
os.path.join(resume_from_checkpoint, "trainer_state.json")
|
||||
):
|
||||
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
|
||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||
if not self.args.ignore_data_skip:
|
||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||
@@ -1102,20 +1120,20 @@ class Trainer:
|
||||
if self.is_world_process_zero():
|
||||
self._rotate_checkpoints(use_mtime=True)
|
||||
|
||||
def _load_optimizer_and_scheduler(self, model_path):
|
||||
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||
"""If optimizer and scheduler states exist, load them."""
|
||||
if model_path is None:
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
if os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(model_path, "scheduler.pt")
|
||||
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
|
||||
os.path.join(checkpoint, "scheduler.pt")
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
if is_torch_tpu_available():
|
||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
|
||||
optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
|
||||
lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||
@@ -1125,15 +1143,15 @@ class Trainer:
|
||||
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
||||
else:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
if self.deepspeed:
|
||||
# Not sure how to check if there is a saved deepspeed checkpoint, but since it just returns None if it fails to find a deepspeed checkpoint, this is sort of a check-n-load function
|
||||
self.deepspeed.load_checkpoint(model_path, load_optimizer_states=True, load_lr_scheduler_states=True)
|
||||
self.deepspeed.load_checkpoint(checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True)
|
||||
|
||||
def hyperparameter_search(
|
||||
self,
|
||||
|
||||
@@ -341,20 +341,20 @@ def main():
|
||||
if training_args.do_train:
|
||||
{%- if cookiecutter.can_train_from_scratch == "False" %}
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
checkpoint = None
|
||||
{%- elif cookiecutter.can_train_from_scratch == "True" %}
|
||||
if last_checkpoint is not None:
|
||||
model_path = last_checkpoint
|
||||
checkpoint = last_checkpoint
|
||||
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
|
||||
model_path = model_args.model_name_or_path
|
||||
checkpoint = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = None
|
||||
checkpoint = None
|
||||
{% endif %}
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
|
||||
@@ -581,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -594,7 +594,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# Reinitialize trainer and load model
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -617,7 +617,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -632,7 +632,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
@@ -670,7 +670,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
learning_rate=0.1,
|
||||
)
|
||||
|
||||
trainer.train(model_path=checkpoint)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
|
||||
Reference in New Issue
Block a user