From 29baa8fabe15393ec4451beceee6d025881ec992 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 1 Oct 2020 13:07:04 -0400 Subject: [PATCH] Clean the Trainer state (#7490) * Trainer should not modify its TrainingArguments * Trainer should not modify its TrainingArguments * Trainer should not modify its TrainingArguments * Add test of resumed training * Fixes * Non multiGPU test * Clean Trainer state * Add more to the state * Documentation * One last test * Make resume training test more complete * Unwanted changes --- src/transformers/__init__.py | 2 +- src/transformers/trainer.py | 150 +++++++++++++----------------- src/transformers/trainer_utils.py | 35 ++++++- tests/test_trainer.py | 62 +++++++++++- 4 files changed, 161 insertions(+), 88 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 84f86bee55..a999ba9c60 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -201,7 +201,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer # Trainer -from .trainer_utils import EvalPrediction, set_seed +from .trainer_utils import EvalPrediction, TrainerState, set_seed from .training_args import TrainingArguments from .training_args_tf import TFTrainingArguments from .utils import logging diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9a78ae61d1..b13f9dbc19 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1,5 +1,4 @@ import inspect -import json import math import os import re @@ -260,10 +259,11 @@ class Trainer: "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) self.tb_writer = tb_writer - self.log_history = [] if "prediction_loss_only" in kwargs: warnings.warn( - "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.", + "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a " + + "future version. Use `args.prediction_loss_only` instead. Setting " + + f"`args.prediction_loss_only={kwargs['prediction_loss_only']}", FutureWarning, ) self.args.prediction_loss_only = kwargs.pop("prediction_loss_only") @@ -302,19 +302,20 @@ class Trainer: if isinstance(eval_dataset, datasets.Dataset): self._remove_unused_columns(self.eval_dataset, description="evaluation") - self.global_step = None - self.epoch = None - self.total_flos = None + self.state = TrainerState() + # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the + # state at each call to self.log. + self._total_flos = None if self.args.fp16 and _use_native_amp: self.scaler = torch.cuda.amp.GradScaler() self.hp_search_backend = None self.use_tune_checkpoints = False - if self.args.label_names is None: - self.args.label_names = ( - ["start_positions, end_positions"] - if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values() - else ["labels"] - ) + default_label_names = ( + ["start_positions, end_positions"] + if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values() + else ["labels"] + ) + self.label_names = default_label_names if self.args.label_names is None else self.args.label_names def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): if not self.args.remove_unused_columns: @@ -588,16 +589,16 @@ class Trainer: if trial.should_prune(): raise optuna.TrialPruned() elif self.hp_search_backend == HPSearchBackend.RAY: - if self.global_step % self.args.save_steps == 0: + if self.state.global_step % self.args.save_steps == 0: self._tune_save_checkpoint() tune.report(objective=self.objective, **metrics) def _tune_save_checkpoint(self): if not self.use_tune_checkpoints: return - with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir: + with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: self.args.output_dir = checkpoint_dir - output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") + output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") self.save_model(output_dir) if self.is_world_master(): torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) @@ -632,16 +633,16 @@ class Trainer: num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) if self.args.max_steps > 0: - t_total = self.args.max_steps + max_steps = self.args.max_steps num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( self.args.max_steps % num_update_steps_per_epoch > 0 ) else: - t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs) + max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs - self.args.max_steps = t_total + num_train_epochs = int(np.ceil(num_train_epochs)) - self.create_optimizer_and_scheduler(num_training_steps=t_total) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() # Check if saved optimizer or scheduler states exist @@ -658,17 +659,14 @@ class Trainer: self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) reissue_pt_warnings(caught_warnings) - # Check if a saved Trainer state exist - if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")): - self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json")) - + # Moxed precision training with apex (torch < 1.6) model = self.model if self.args.fp16 and _use_apex: if not is_apex_available(): raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) - # multi-gpu training (should be after apex fp16 initialization) + # Multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) @@ -706,37 +704,35 @@ class Trainer: logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", t_total) + logger.info(" Total optimization steps = %d", max_steps) - self.global_step = 0 - self.epoch = 0 + self.state.epoch = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint - if model_path is not None: - # set global_step to global_step of last saved checkpoint from model path - try: - self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) + if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")): + self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json")) + epochs_trained = self.state.global_step // num_update_steps_per_epoch + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) - epochs_trained = self.global_step // num_update_steps_per_epoch - steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch) + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(" Continuing training from epoch %d", epochs_trained) + logger.info(" Continuing training from global step %d", self.state.global_step) + logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) - logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(" Continuing training from epoch %d", epochs_trained) - logger.info(" Continuing training from global step %d", self.global_step) - logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) - except ValueError: - self.global_step = 0 - logger.info(" Starting fine-tuning.") + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs tr_loss = torch.tensor(0.0).to(self.args.device) - self.total_flos = self.state.total_flos + self._total_flos = self.state.total_flos logging_loss_scalar = 0.0 model.zero_grad() disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() - train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm) - for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))): + train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm) + for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) @@ -762,7 +758,7 @@ class Trainer: continue tr_loss += self.training_step(model, inputs) - self.total_flos += self.floating_point_ops(inputs) + self._total_flos += self.floating_point_ops(inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps @@ -787,11 +783,11 @@ class Trainer: self.lr_scheduler.step() model.zero_grad() - self.global_step += 1 - self.epoch = epoch + (step + 1) / len(epoch_iterator) + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1) / len(epoch_iterator) - if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or ( - self.global_step == 1 and self.args.logging_first_step + if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or ( + self.state.global_step == 1 and self.args.logging_first_step ): logs: Dict[str, float] = {} tr_loss_scalar = tr_loss.item() @@ -808,7 +804,7 @@ class Trainer: if ( self.args.evaluation_strategy == EvaluationStrategy.STEPS - and self.global_step % self.args.eval_steps == 0 + and self.state.global_step % self.args.eval_steps == 0 ): metrics = self.evaluate() self._report_to_hp_search(trial, epoch, metrics) @@ -818,12 +814,12 @@ class Trainer: if ( not self.args.load_best_model_at_end and self.args.save_steps > 0 - and self.global_step % self.args.save_steps == 0 + and self.state.global_step % self.args.save_steps == 0 ): self._save_training(model, trial) epoch_pbar.update(1) - if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: + if self.state.global_step >= max_steps: break epoch_pbar.close() train_pbar.update(1) @@ -843,7 +839,7 @@ class Trainer: "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) - if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: + if self.state.global_step >= max_steps: break train_pbar.close() @@ -865,7 +861,7 @@ class Trainer: state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) self.model.load_state_dict(state_dict) - return TrainOutput(self.global_step, tr_loss.item() / self.global_step) + return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step) def _save_training(self, model, trial, metrics=None): # In all cases (even distributed/parallel), self.model is always a reference @@ -875,7 +871,7 @@ class Trainer: else: assert model is self.model, f"Model {model} should be a reference to self.model" # Save model checkpoint - checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}" + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" if self.hp_search_backend is not None and trial is not None: run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id() checkpoint_folder += f"-run-{run_id}" @@ -1022,22 +1018,15 @@ class Trainer: ) return self._log(logs, iterator=iterator) - if self.epoch is not None: - logs["epoch"] = self.epoch - if self.total_flos is not None: - if self.args.local_rank != -1: - total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() - else: - total_flos = self.total_flos - if total_flos > 0: - logs["total_flos"] = total_flos - if self.global_step is None: - # when logging evaluation metrics without training - self.global_step = 0 + if self.state.epoch is not None: + logs["epoch"] = self.state.epoch + if self._total_flos is not None: + self.store_flos() + logs["total_flos"] = self.state.total_flos if self.tb_writer: for k, v in logs.items(): if isinstance(v, (int, float)): - self.tb_writer.add_scalar(k, v, self.global_step) + self.tb_writer.add_scalar(k, v, self.state.global_step) else: logger.warning( "Trainer is attempting to log a value of " @@ -1051,15 +1040,16 @@ class Trainer: self.tb_writer.flush() if is_wandb_available(): if self.is_world_process_zero(): - wandb.log(logs, step=self.global_step) + wandb.log(logs, step=self.state.global_step) if is_comet_available(): if self.is_world_process_zero(): experiment = comet_ml.config.get_global_experiment() if experiment is not None: - experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers") - output = {**logs, **{"step": self.global_step}} - if self.is_world_process_zero(): - self.log_history.append(output) + experiment._log_metrics( + logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers" + ) + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) if iterator is not None: iterator.write(output) else: @@ -1205,9 +1195,6 @@ class Trainer: if xm.is_master_ordinal(): os.makedirs(output_dir, exist_ok=True) torch.save(self.args, os.path.join(output_dir, "training_args.bin")) - json.dump( - self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False - ) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` @@ -1238,17 +1225,14 @@ class Trainer: # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, "training_args.bin")) - json.dump( - self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False - ) def store_flos(self): # Storing the number of floating-point operations that went into the model - if self.total_flos is not None: + if self._total_flos is not None: if self.args.local_rank != -1: - self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() + self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item() else: - self.state.total_flos = self.total_flos + self.state.total_flos = self._total_flos def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: ordering_and_checkpoint_path = [] @@ -1466,7 +1450,7 @@ class Trainer: Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). """ - has_labels = all(inputs.get(k) is not None for k in self.args.label_names) + has_labels = all(inputs.get(k) is not None for k in self.label_names) inputs = self._prepare_inputs(inputs) with torch.no_grad(): @@ -1490,7 +1474,7 @@ class Trainer: logits = logits[0] if has_labels: - labels = tuple(inputs.get(name).detach() for name in self.args.label_names) + labels = tuple(inputs.get(name).detach() for name in self.label_names) if len(labels) == 1: labels = labels[0] else: diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 334a8e4d95..b3207ec359 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -221,13 +221,46 @@ def distributed_broadcast_scalars( @dataclass class TrainerState: """ - A class containing the `Trainer` fields that will be saved along the model and optimizer. + A class containing the `Trainer` inner state that will be saved along the model and optimizer. + + .. note:: + + In all this class, one step is to be understood as one update step. When using gradient accumulation, one + update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`, + then one update step requires going throuch `n` batches. + + Args: + epoch (:obj:`float`, `optional`): + Only set during training, will represent the epoch the training is at (the decimal part being the + percentage of the current epoch completed). + global_step (:obj:`int`, `optional`, defaults to 0): + During training, represents the number of update steps completed. + max_steps (:obj:`int`, `optional`, defaults to 0): + The number of update steps to do during the current training. + total_flos (:obj:`int`, `optional`, defaults to 0): + The total number of floating operations done by the model since the beginning of training. + log_history (:obj:`List[Dict[str, float]]`, `optional`): + The list of logs done since the beginning of training. + best_metric (:obj:`float`, `optional`): + When tracking the best model, the value of the best metric encountered so far. + best_model_checkpoint (:obj:`str`, `optional`): + When tracking the best model, the value of the name of the checkpoint for the best model encountered so + far. """ + epoch: Optional[float] = None + global_step: int = 0 + max_steps: int = 0 + num_train_epochs: int = 0 total_flos: int = 0 + log_history: List[Dict[str, float]] = None best_metric: Optional[float] = None best_model_checkpoint: Optional[str] = None + def __post_init__(self): + if self.log_history is None: + self.log_history = [] + def save_to_json(self, json_path: str): """ Save the content of this instance in JSON format inside :obj:`json_path`.""" json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 18fc6551f9..a758ced241 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,4 +1,4 @@ -import json +import dataclasses import os import tempfile import unittest @@ -22,6 +22,7 @@ if is_torch_available(): LineByLineTextDataset, PreTrainedModel, Trainer, + TrainerState, ) @@ -155,7 +156,7 @@ class TrainerIntegrationTest(unittest.TestCase): self.assertTrue(torch.allclose(model.b, b)) def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True): - file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"] + file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"] if is_pretrained: file_list.append("config.json") for step in range(freq, total, freq): @@ -168,7 +169,7 @@ class TrainerIntegrationTest(unittest.TestCase): self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True ): checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}") - log_history = json.load(open(os.path.join(checkpoint, "log_history.json"))) + log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history values = [d[metric] for d in log_history] best_value = max(values) if greater_is_better else min(values) @@ -188,6 +189,12 @@ class TrainerIntegrationTest(unittest.TestCase): metrics = trainer.evaluate() self.assertEqual(metrics[metric], best_value) + def test_training_arguments_are_left_untouched(self): + trainer = get_regression_trainer() + trainer.train() + args = TrainingArguments("./regression") + self.assertEqual(args.to_dict(), trainer.args.to_dict()) + def test_reproducible_training(self): # Checks that training worked, model trained and seed made a reproducible training. trainer = get_regression_trainer(learning_rate=0.1) @@ -368,6 +375,55 @@ class TrainerIntegrationTest(unittest.TestCase): trainer.train() self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False) + def test_can_resume_training(self): + if torch.cuda.device_count() > 2: + # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of + # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model + # won't be the same since the training dataloader is shuffled). + return + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1) + trainer.train() + (a, b) = trainer.model.a.item(), trainer.model.b.item() + state = dataclasses.asdict(trainer.state) + + checkpoint = os.path.join(tmpdir, "checkpoint-5") + + # Reinitialize trainer and load model + model = RegressionPreTrainedModel.from_pretrained(checkpoint) + trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + + trainer.train(model_path=checkpoint) + (a1, b1) = trainer.model.a.item(), trainer.model.b.item() + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(a, a1) + self.assertEqual(b, b1) + self.assertEqual(state, state1) + + # With a regular model that is not a PreTrainedModel + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, pretrained=False + ) + trainer.train() + (a, b) = trainer.model.a.item(), trainer.model.b.item() + state = dataclasses.asdict(trainer.state) + + checkpoint = os.path.join(tmpdir, "checkpoint-5") + + # Reinitialize trainer and load model + model = RegressionModel() + state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) + model.load_state_dict(state_dict) + trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + + trainer.train(model_path=checkpoint) + (a1, b1) = trainer.model.a.item(), trainer.model.b.item() + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(a, a1) + self.assertEqual(b, b1) + self.assertEqual(state, state1) + def test_load_best_model_at_end(self): total = int(self.n_epochs * 64 / self.batch_size) with tempfile.TemporaryDirectory() as tmpdir: