Clean the Trainer state (#7490)

* Trainer should not modify its TrainingArguments

* Trainer should not modify its TrainingArguments

* Trainer should not modify its TrainingArguments

* Add test of resumed training

* Fixes

* Non multiGPU test

* Clean Trainer state

* Add more to the state

* Documentation

* One last test

* Make resume training test more complete

* Unwanted changes
This commit is contained in:
Sylvain Gugger
2020-10-01 13:07:04 -04:00
committed by GitHub
parent 2a358f45ef
commit 29baa8fabe
4 changed files with 161 additions and 88 deletions

View File

@@ -1,5 +1,4 @@
import inspect
import json
import math
import os
import re
@@ -260,10 +259,11 @@ class Trainer:
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
self.tb_writer = tb_writer
self.log_history = []
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
+ "future version. Use `args.prediction_loss_only` instead. Setting "
+ f"`args.prediction_loss_only={kwargs['prediction_loss_only']}",
FutureWarning,
)
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
@@ -302,19 +302,20 @@ class Trainer:
if isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation")
self.global_step = None
self.epoch = None
self.total_flos = None
self.state = TrainerState()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
# state at each call to self.log.
self._total_flos = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
self.use_tune_checkpoints = False
if self.args.label_names is None:
self.args.label_names = (
["start_positions, end_positions"]
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
else ["labels"]
)
default_label_names = (
["start_positions, end_positions"]
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
else ["labels"]
)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
@@ -588,16 +589,16 @@ class Trainer:
if trial.should_prune():
raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY:
if self.global_step % self.args.save_steps == 0:
if self.state.global_step % self.args.save_steps == 0:
self._tune_save_checkpoint()
tune.report(objective=self.objective, **metrics)
def _tune_save_checkpoint(self):
if not self.use_tune_checkpoints:
return
with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
self.args.output_dir = checkpoint_dir
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir)
if self.is_world_master():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -632,16 +633,16 @@ class Trainer:
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
t_total = self.args.max_steps
max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
)
else:
t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
self.args.max_steps = t_total
num_train_epochs = int(np.ceil(num_train_epochs))
self.create_optimizer_and_scheduler(num_training_steps=t_total)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
# Check if saved optimizer or scheduler states exist
@@ -658,17 +659,14 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
# Check if a saved Trainer state exists
if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
# Mixed precision training with apex (torch < 1.6)
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
# Multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model)
@@ -706,37 +704,35 @@ class Trainer:
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
logger.info(" Total optimization steps = %d", max_steps)
self.global_step = 0
self.epoch = 0
self.state.epoch = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if model_path is not None:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
epochs_trained = self.state.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
epochs_trained = self.global_step // num_update_steps_per_epoch
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.state.global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0
logger.info(" Starting fine-tuning.")
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
tr_loss = torch.tensor(0.0).to(self.args.device)
self.total_flos = self.state.total_flos
self._total_flos = self.state.total_flos
logging_loss_scalar = 0.0
model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
@@ -762,7 +758,7 @@ class Trainer:
continue
tr_loss += self.training_step(model, inputs)
self.total_flos += self.floating_point_ops(inputs)
self._total_flos += self.floating_point_ops(inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -787,11 +783,11 @@ class Trainer:
self.lr_scheduler.step()
model.zero_grad()
self.global_step += 1
self.epoch = epoch + (step + 1) / len(epoch_iterator)
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
self.global_step == 1 and self.args.logging_first_step
if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
self.state.global_step == 1 and self.args.logging_first_step
):
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
@@ -808,7 +804,7 @@ class Trainer:
if (
self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.global_step % self.args.eval_steps == 0
and self.state.global_step % self.args.eval_steps == 0
):
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
@@ -818,12 +814,12 @@ class Trainer:
if (
not self.args.load_best_model_at_end
and self.args.save_steps > 0
and self.global_step % self.args.save_steps == 0
and self.state.global_step % self.args.save_steps == 0
):
self._save_training(model, trial)
epoch_pbar.update(1)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
if self.state.global_step >= max_steps:
break
epoch_pbar.close()
train_pbar.update(1)
@@ -843,7 +839,7 @@ class Trainer:
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
if self.state.global_step >= max_steps:
break
train_pbar.close()
@@ -865,7 +861,7 @@ class Trainer:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
def _save_training(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference
@@ -875,7 +871,7 @@ class Trainer:
else:
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}"
@@ -1022,22 +1018,15 @@ class Trainer:
)
return self._log(logs, iterator=iterator)
if self.epoch is not None:
logs["epoch"] = self.epoch
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = total_flos
if self.global_step is None:
# when logging evaluation metrics without training
self.global_step = 0
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self._total_flos is not None:
self.store_flos()
logs["total_flos"] = self.state.total_flos
if self.tb_writer:
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, self.global_step)
self.tb_writer.add_scalar(k, v, self.state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
@@ -1051,15 +1040,16 @@ class Trainer:
self.tb_writer.flush()
if is_wandb_available():
if self.is_world_process_zero():
wandb.log(logs, step=self.global_step)
wandb.log(logs, step=self.state.global_step)
if is_comet_available():
if self.is_world_process_zero():
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
output = {**logs, **{"step": self.global_step}}
if self.is_world_process_zero():
self.log_history.append(output)
experiment._log_metrics(
logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
)
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
if iterator is not None:
iterator.write(output)
else:
@@ -1205,9 +1195,6 @@ class Trainer:
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
@@ -1238,17 +1225,14 @@ class Trainer:
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
def store_flos(self):
# Storing the number of floating-point operations that went into the model
if self.total_flos is not None:
if self._total_flos is not None:
if self.args.local_rank != -1:
self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
else:
self.state.total_flos = self.total_flos
self.state.total_flos = self._total_flos
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = []
@@ -1466,7 +1450,7 @@ class Trainer:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional).
"""
has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
@@ -1490,7 +1474,7 @@ class Trainer:
logits = logits[0]
if has_labels:
labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
labels = tuple(inputs.get(name).detach() for name in self.label_names)
if len(labels) == 1:
labels = labels[0]
else: