Clean the Trainer state (#7490)
* Trainer should not modify its TrainingArguments * Trainer should not modify its TrainingArguments * Trainer should not modify its TrainingArguments * Add test of resumed training * Fixes * Non multiGPU test * Clean Trainer state * Add more to the state * Documentation * One last test * Make resume training test more complete * Unwanted changes
This commit is contained in:
@@ -201,7 +201,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
|||||||
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
from .trainer_utils import EvalPrediction, set_seed
|
from .trainer_utils import EvalPrediction, TrainerState, set_seed
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
from .training_args_tf import TFTrainingArguments
|
from .training_args_tf import TFTrainingArguments
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -260,10 +259,11 @@ class Trainer:
|
|||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
self.tb_writer = tb_writer
|
self.tb_writer = tb_writer
|
||||||
self.log_history = []
|
|
||||||
if "prediction_loss_only" in kwargs:
|
if "prediction_loss_only" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
|
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
|
||||||
|
+ "future version. Use `args.prediction_loss_only` instead. Setting "
|
||||||
|
+ f"`args.prediction_loss_only={kwargs['prediction_loss_only']}",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
|
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
|
||||||
@@ -302,19 +302,20 @@ class Trainer:
|
|||||||
if isinstance(eval_dataset, datasets.Dataset):
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||||
|
|
||||||
self.global_step = None
|
self.state = TrainerState()
|
||||||
self.epoch = None
|
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
||||||
self.total_flos = None
|
# state at each call to self.log.
|
||||||
|
self._total_flos = None
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.args.fp16 and _use_native_amp:
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
self.hp_search_backend = None
|
self.hp_search_backend = None
|
||||||
self.use_tune_checkpoints = False
|
self.use_tune_checkpoints = False
|
||||||
if self.args.label_names is None:
|
default_label_names = (
|
||||||
self.args.label_names = (
|
["start_positions, end_positions"]
|
||||||
["start_positions, end_positions"]
|
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
|
||||||
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
|
else ["labels"]
|
||||||
else ["labels"]
|
)
|
||||||
)
|
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||||
|
|
||||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||||
if not self.args.remove_unused_columns:
|
if not self.args.remove_unused_columns:
|
||||||
@@ -588,16 +589,16 @@ class Trainer:
|
|||||||
if trial.should_prune():
|
if trial.should_prune():
|
||||||
raise optuna.TrialPruned()
|
raise optuna.TrialPruned()
|
||||||
elif self.hp_search_backend == HPSearchBackend.RAY:
|
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||||
if self.global_step % self.args.save_steps == 0:
|
if self.state.global_step % self.args.save_steps == 0:
|
||||||
self._tune_save_checkpoint()
|
self._tune_save_checkpoint()
|
||||||
tune.report(objective=self.objective, **metrics)
|
tune.report(objective=self.objective, **metrics)
|
||||||
|
|
||||||
def _tune_save_checkpoint(self):
|
def _tune_save_checkpoint(self):
|
||||||
if not self.use_tune_checkpoints:
|
if not self.use_tune_checkpoints:
|
||||||
return
|
return
|
||||||
with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
|
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
|
||||||
self.args.output_dir = checkpoint_dir
|
self.args.output_dir = checkpoint_dir
|
||||||
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
|
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||||
self.save_model(output_dir)
|
self.save_model(output_dir)
|
||||||
if self.is_world_master():
|
if self.is_world_master():
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
@@ -632,16 +633,16 @@ class Trainer:
|
|||||||
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
|
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||||
if self.args.max_steps > 0:
|
if self.args.max_steps > 0:
|
||||||
t_total = self.args.max_steps
|
max_steps = self.args.max_steps
|
||||||
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
|
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
|
||||||
self.args.max_steps % num_update_steps_per_epoch > 0
|
self.args.max_steps % num_update_steps_per_epoch > 0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
|
max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
|
||||||
num_train_epochs = self.args.num_train_epochs
|
num_train_epochs = self.args.num_train_epochs
|
||||||
self.args.max_steps = t_total
|
num_train_epochs = int(np.ceil(num_train_epochs))
|
||||||
|
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=t_total)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
@@ -658,17 +659,14 @@ class Trainer:
|
|||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
# Check if a saved Trainer state exist
|
# Moxed precision training with apex (torch < 1.6)
|
||||||
if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
|
|
||||||
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
|
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
if self.args.fp16 and _use_apex:
|
if self.args.fp16 and _use_apex:
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# Multi-gpu training (should be after apex fp16 initialization)
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
@@ -706,37 +704,35 @@ class Trainer:
|
|||||||
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
||||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
|
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
|
||||||
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", max_steps)
|
||||||
|
|
||||||
self.global_step = 0
|
self.state.epoch = 0
|
||||||
self.epoch = 0
|
|
||||||
epochs_trained = 0
|
epochs_trained = 0
|
||||||
steps_trained_in_current_epoch = 0
|
steps_trained_in_current_epoch = 0
|
||||||
|
|
||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if model_path is not None:
|
if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
|
||||||
# set global_step to global_step of last saved checkpoint from model path
|
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
|
||||||
try:
|
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||||
|
|
||||||
epochs_trained = self.global_step // num_update_steps_per_epoch
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||||
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
|
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||||
|
logger.info(" Continuing training from global step %d", self.state.global_step)
|
||||||
|
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||||
|
|
||||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
# to set this after the load.
|
||||||
logger.info(" Continuing training from global step %d", self.global_step)
|
self.state.max_steps = max_steps
|
||||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
self.state.num_train_epochs = num_train_epochs
|
||||||
except ValueError:
|
|
||||||
self.global_step = 0
|
|
||||||
logger.info(" Starting fine-tuning.")
|
|
||||||
|
|
||||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||||
self.total_flos = self.state.total_flos
|
self._total_flos = self.state.total_flos
|
||||||
logging_loss_scalar = 0.0
|
logging_loss_scalar = 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
||||||
train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
|
train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
|
||||||
for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
|
for epoch in range(epochs_trained, num_train_epochs):
|
||||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
train_dataloader.sampler.set_epoch(epoch)
|
train_dataloader.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
@@ -762,7 +758,7 @@ class Trainer:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
tr_loss += self.training_step(model, inputs)
|
tr_loss += self.training_step(model, inputs)
|
||||||
self.total_flos += self.floating_point_ops(inputs)
|
self._total_flos += self.floating_point_ops(inputs)
|
||||||
|
|
||||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||||
@@ -787,11 +783,11 @@ class Trainer:
|
|||||||
|
|
||||||
self.lr_scheduler.step()
|
self.lr_scheduler.step()
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
self.global_step += 1
|
self.state.global_step += 1
|
||||||
self.epoch = epoch + (step + 1) / len(epoch_iterator)
|
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
|
||||||
|
|
||||||
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
|
if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
|
||||||
self.global_step == 1 and self.args.logging_first_step
|
self.state.global_step == 1 and self.args.logging_first_step
|
||||||
):
|
):
|
||||||
logs: Dict[str, float] = {}
|
logs: Dict[str, float] = {}
|
||||||
tr_loss_scalar = tr_loss.item()
|
tr_loss_scalar = tr_loss.item()
|
||||||
@@ -808,7 +804,7 @@ class Trainer:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||||
and self.global_step % self.args.eval_steps == 0
|
and self.state.global_step % self.args.eval_steps == 0
|
||||||
):
|
):
|
||||||
metrics = self.evaluate()
|
metrics = self.evaluate()
|
||||||
self._report_to_hp_search(trial, epoch, metrics)
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
@@ -818,12 +814,12 @@ class Trainer:
|
|||||||
if (
|
if (
|
||||||
not self.args.load_best_model_at_end
|
not self.args.load_best_model_at_end
|
||||||
and self.args.save_steps > 0
|
and self.args.save_steps > 0
|
||||||
and self.global_step % self.args.save_steps == 0
|
and self.state.global_step % self.args.save_steps == 0
|
||||||
):
|
):
|
||||||
self._save_training(model, trial)
|
self._save_training(model, trial)
|
||||||
|
|
||||||
epoch_pbar.update(1)
|
epoch_pbar.update(1)
|
||||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
if self.state.global_step >= max_steps:
|
||||||
break
|
break
|
||||||
epoch_pbar.close()
|
epoch_pbar.close()
|
||||||
train_pbar.update(1)
|
train_pbar.update(1)
|
||||||
@@ -843,7 +839,7 @@ class Trainer:
|
|||||||
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
||||||
"configured. Check your training configuration if this is unexpected."
|
"configured. Check your training configuration if this is unexpected."
|
||||||
)
|
)
|
||||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
if self.state.global_step >= max_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
train_pbar.close()
|
train_pbar.close()
|
||||||
@@ -865,7 +861,7 @@ class Trainer:
|
|||||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
|
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||||
|
|
||||||
def _save_training(self, model, trial, metrics=None):
|
def _save_training(self, model, trial, metrics=None):
|
||||||
# In all cases (even distributed/parallel), self.model is always a reference
|
# In all cases (even distributed/parallel), self.model is always a reference
|
||||||
@@ -875,7 +871,7 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
assert model is self.model, f"Model {model} should be a reference to self.model"
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
if self.hp_search_backend is not None and trial is not None:
|
if self.hp_search_backend is not None and trial is not None:
|
||||||
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
||||||
checkpoint_folder += f"-run-{run_id}"
|
checkpoint_folder += f"-run-{run_id}"
|
||||||
@@ -1022,22 +1018,15 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
return self._log(logs, iterator=iterator)
|
return self._log(logs, iterator=iterator)
|
||||||
|
|
||||||
if self.epoch is not None:
|
if self.state.epoch is not None:
|
||||||
logs["epoch"] = self.epoch
|
logs["epoch"] = self.state.epoch
|
||||||
if self.total_flos is not None:
|
if self._total_flos is not None:
|
||||||
if self.args.local_rank != -1:
|
self.store_flos()
|
||||||
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
logs["total_flos"] = self.state.total_flos
|
||||||
else:
|
|
||||||
total_flos = self.total_flos
|
|
||||||
if total_flos > 0:
|
|
||||||
logs["total_flos"] = total_flos
|
|
||||||
if self.global_step is None:
|
|
||||||
# when logging evaluation metrics without training
|
|
||||||
self.global_step = 0
|
|
||||||
if self.tb_writer:
|
if self.tb_writer:
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
if isinstance(v, (int, float)):
|
if isinstance(v, (int, float)):
|
||||||
self.tb_writer.add_scalar(k, v, self.global_step)
|
self.tb_writer.add_scalar(k, v, self.state.global_step)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Trainer is attempting to log a value of "
|
"Trainer is attempting to log a value of "
|
||||||
@@ -1051,15 +1040,16 @@ class Trainer:
|
|||||||
self.tb_writer.flush()
|
self.tb_writer.flush()
|
||||||
if is_wandb_available():
|
if is_wandb_available():
|
||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
wandb.log(logs, step=self.global_step)
|
wandb.log(logs, step=self.state.global_step)
|
||||||
if is_comet_available():
|
if is_comet_available():
|
||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
experiment = comet_ml.config.get_global_experiment()
|
experiment = comet_ml.config.get_global_experiment()
|
||||||
if experiment is not None:
|
if experiment is not None:
|
||||||
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
|
experiment._log_metrics(
|
||||||
output = {**logs, **{"step": self.global_step}}
|
logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
|
||||||
if self.is_world_process_zero():
|
)
|
||||||
self.log_history.append(output)
|
output = {**logs, **{"step": self.state.global_step}}
|
||||||
|
self.state.log_history.append(output)
|
||||||
if iterator is not None:
|
if iterator is not None:
|
||||||
iterator.write(output)
|
iterator.write(output)
|
||||||
else:
|
else:
|
||||||
@@ -1205,9 +1195,6 @@ class Trainer:
|
|||||||
if xm.is_master_ordinal():
|
if xm.is_master_ordinal():
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||||
json.dump(
|
|
||||||
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save a trained model and configuration using `save_pretrained()`.
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
@@ -1238,17 +1225,14 @@ class Trainer:
|
|||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||||
json.dump(
|
|
||||||
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def store_flos(self):
|
def store_flos(self):
|
||||||
# Storing the number of floating-point operations that went into the model
|
# Storing the number of floating-point operations that went into the model
|
||||||
if self.total_flos is not None:
|
if self._total_flos is not None:
|
||||||
if self.args.local_rank != -1:
|
if self.args.local_rank != -1:
|
||||||
self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
|
||||||
else:
|
else:
|
||||||
self.state.total_flos = self.total_flos
|
self.state.total_flos = self._total_flos
|
||||||
|
|
||||||
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
||||||
ordering_and_checkpoint_path = []
|
ordering_and_checkpoint_path = []
|
||||||
@@ -1466,7 +1450,7 @@ class Trainer:
|
|||||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
A tuple with the loss, logits and labels (each being optional).
|
A tuple with the loss, logits and labels (each being optional).
|
||||||
"""
|
"""
|
||||||
has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
|
has_labels = all(inputs.get(k) is not None for k in self.label_names)
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -1490,7 +1474,7 @@ class Trainer:
|
|||||||
logits = logits[0]
|
logits = logits[0]
|
||||||
|
|
||||||
if has_labels:
|
if has_labels:
|
||||||
labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
|
labels = tuple(inputs.get(name).detach() for name in self.label_names)
|
||||||
if len(labels) == 1:
|
if len(labels) == 1:
|
||||||
labels = labels[0]
|
labels = labels[0]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -221,13 +221,46 @@ def distributed_broadcast_scalars(
|
|||||||
@dataclass
|
@dataclass
|
||||||
class TrainerState:
|
class TrainerState:
|
||||||
"""
|
"""
|
||||||
A class containing the `Trainer` fields that will be saved along the model and optimizer.
|
A class containing the `Trainer` inner state that will be saved along the model and optimizer.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
|
||||||
|
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
|
||||||
|
then one update step requires going throuch `n` batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch (:obj:`float`, `optional`):
|
||||||
|
Only set during training, will represent the epoch the training is at (the decimal part being the
|
||||||
|
percentage of the current epoch completed).
|
||||||
|
global_step (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
During training, represents the number of update steps completed.
|
||||||
|
max_steps (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
The number of update steps to do during the current training.
|
||||||
|
total_flos (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
The total number of floating operations done by the model since the beginning of training.
|
||||||
|
log_history (:obj:`List[Dict[str, float]]`, `optional`):
|
||||||
|
The list of logs done since the beginning of training.
|
||||||
|
best_metric (:obj:`float`, `optional`):
|
||||||
|
When tracking the best model, the value of the best metric encountered so far.
|
||||||
|
best_model_checkpoint (:obj:`str`, `optional`):
|
||||||
|
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
|
||||||
|
far.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
epoch: Optional[float] = None
|
||||||
|
global_step: int = 0
|
||||||
|
max_steps: int = 0
|
||||||
|
num_train_epochs: int = 0
|
||||||
total_flos: int = 0
|
total_flos: int = 0
|
||||||
|
log_history: List[Dict[str, float]] = None
|
||||||
best_metric: Optional[float] = None
|
best_metric: Optional[float] = None
|
||||||
best_model_checkpoint: Optional[str] = None
|
best_model_checkpoint: Optional[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.log_history is None:
|
||||||
|
self.log_history = []
|
||||||
|
|
||||||
def save_to_json(self, json_path: str):
|
def save_to_json(self, json_path: str):
|
||||||
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
|
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
|
||||||
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
|
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import json
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -22,6 +22,7 @@ if is_torch_available():
|
|||||||
LineByLineTextDataset,
|
LineByLineTextDataset,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
Trainer,
|
Trainer,
|
||||||
|
TrainerState,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +156,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(model.b, b))
|
self.assertTrue(torch.allclose(model.b, b))
|
||||||
|
|
||||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
|
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
|
||||||
file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"]
|
file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||||
if is_pretrained:
|
if is_pretrained:
|
||||||
file_list.append("config.json")
|
file_list.append("config.json")
|
||||||
for step in range(freq, total, freq):
|
for step in range(freq, total, freq):
|
||||||
@@ -168,7 +169,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
|
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
|
||||||
):
|
):
|
||||||
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
||||||
log_history = json.load(open(os.path.join(checkpoint, "log_history.json")))
|
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
|
||||||
|
|
||||||
values = [d[metric] for d in log_history]
|
values = [d[metric] for d in log_history]
|
||||||
best_value = max(values) if greater_is_better else min(values)
|
best_value = max(values) if greater_is_better else min(values)
|
||||||
@@ -188,6 +189,12 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
self.assertEqual(metrics[metric], best_value)
|
self.assertEqual(metrics[metric], best_value)
|
||||||
|
|
||||||
|
def test_training_arguments_are_left_untouched(self):
|
||||||
|
trainer = get_regression_trainer()
|
||||||
|
trainer.train()
|
||||||
|
args = TrainingArguments("./regression")
|
||||||
|
self.assertEqual(args.to_dict(), trainer.args.to_dict())
|
||||||
|
|
||||||
def test_reproducible_training(self):
|
def test_reproducible_training(self):
|
||||||
# Checks that training worked, model trained and seed made a reproducible training.
|
# Checks that training worked, model trained and seed made a reproducible training.
|
||||||
trainer = get_regression_trainer(learning_rate=0.1)
|
trainer = get_regression_trainer(learning_rate=0.1)
|
||||||
@@ -368,6 +375,55 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
||||||
|
|
||||||
|
def test_can_resume_training(self):
|
||||||
|
if torch.cuda.device_count() > 2:
|
||||||
|
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||||
|
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||||
|
# won't be the same since the training dataloader is shuffled).
|
||||||
|
return
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
|
||||||
|
trainer.train()
|
||||||
|
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
state = dataclasses.asdict(trainer.state)
|
||||||
|
|
||||||
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
|
# Reinitialize trainer and load model
|
||||||
|
model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||||
|
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||||
|
|
||||||
|
trainer.train(model_path=checkpoint)
|
||||||
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
|
self.assertEqual(a, a1)
|
||||||
|
self.assertEqual(b, b1)
|
||||||
|
self.assertEqual(state, state1)
|
||||||
|
|
||||||
|
# With a regular model that is not a PreTrainedModel
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
state = dataclasses.asdict(trainer.state)
|
||||||
|
|
||||||
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
|
# Reinitialize trainer and load model
|
||||||
|
model = RegressionModel()
|
||||||
|
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset)
|
||||||
|
|
||||||
|
trainer.train(model_path=checkpoint)
|
||||||
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
|
self.assertEqual(a, a1)
|
||||||
|
self.assertEqual(b, b1)
|
||||||
|
self.assertEqual(state, state1)
|
||||||
|
|
||||||
def test_load_best_model_at_end(self):
|
def test_load_best_model_at_end(self):
|
||||||
total = int(self.n_epochs * 64 / self.batch_size)
|
total = int(self.n_epochs * 64 / self.batch_size)
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
|||||||
Reference in New Issue
Block a user