def _save_scaler(self, output_dir):
    """Save the accelerator's gradient-scaler state into ``output_dir``.

    Does nothing when the accelerator exposes no ``scaler`` attribute or when
    that attribute is ``None``.

    Args:
        output_dir: Directory that receives the ``SCALER_NAME`` file.
    """
    # `getattr` with a default covers both "attribute missing" and "attribute is None".
    scaler = getattr(getattr(self, "accelerator", None), "scaler", None)
    if scaler is None:
        return

    scaler_path = os.path.join(output_dir, SCALER_NAME)

    if is_torch_xla_available():
        # On XLA the save goes through `xm.save` after a rendezvous and is not
        # gated on `self.args.should_save`.
        xm.rendezvous("saving_scaler_state")
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(scaler.state_dict(), scaler_path)
        # Re-issue after leaving the recording context so the warnings are not captured again.
        reissue_pt_warnings(caught_warnings)
    elif self.args.should_save:
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(scaler.state_dict(), scaler_path)
        reissue_pt_warnings(caught_warnings)


def _load_scaler(self, checkpoint):
    """If scaler state exists in ``checkpoint``, load it into the active scaler.

    Returns without doing anything when ``checkpoint`` is ``None``, when no
    gradient scaler is active on the accelerator, or when the checkpoint
    directory has no ``SCALER_NAME`` file.

    Args:
        checkpoint: Path of the checkpoint directory to resume from, or ``None``.
    """
    if checkpoint is None:
        return

    # Mirror `_save_scaler`: without an active scaler there is nothing to restore
    # into, even if the checkpoint was written by a run that used one.
    scaler = getattr(getattr(self, "accelerator", None), "scaler", None)
    if scaler is None:
        return

    scaler_path = os.path.join(checkpoint, SCALER_NAME)
    if not os.path.isfile(scaler_path):
        return

    # NOTE(review): `torch.load` unpickles the file; only resume from trusted checkpoints.
    if is_torch_xla_available():
        # On TPU the state is first loaded on CPU, then handed to the XLA device.
        with warnings.catch_warnings(record=True) as caught_warnings:
            scaler_state = torch.load(scaler_path, map_location="cpu")
        reissue_pt_warnings(caught_warnings)
        # NOTE(review): the return value is discarded, exactly as before; confirm
        # whether `send_cpu_data_to_device` converts `scaler_state` in place.
        xm.send_cpu_data_to_device(scaler_state, self.args.device)
        scaler.load_state_dict(scaler_state)
    else:
        with warnings.catch_warnings(record=True) as caught_warnings:
            scaler.load_state_dict(torch.load(scaler_path))
        reissue_pt_warnings(caught_warnings)
def get_language_model_trainer(**kwargs):
    """Build a `Trainer` that fine-tunes GPT-2 on the "fka/awesome-chatgpt-prompts" dataset.

    Args:
        **kwargs: Forwarded unchanged to `TrainingArguments`.

    Returns:
        A `Trainer` wrapping the GPT-2 causal LM, the training arguments and the
        tokenized "train" split.
    """
    import datasets

    checkpoint_name = "openai-community/gpt2"

    raw_dataset = datasets.load_dataset("fka/awesome-chatgpt-prompts")
    model = AutoModelForCausalLM.from_pretrained(checkpoint_name)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
    # The EOS token is reused as the padding token so that `padding="max_length"` can be applied.
    tokenizer.pad_token = tokenizer.eos_token

    def _tokenize_function(examples):
        # Labels are a copy of the input ids, cast to int64.
        encoded = tokenizer(examples["prompt"], padding="max_length", truncation=True)
        encoded["labels"] = np.array(encoded["input_ids"]).astype(np.int64)
        return encoded

    train_split = raw_dataset.map(_tokenize_function, batched=True)["train"]

    return Trainer(
        model=model,
        args=TrainingArguments(**kwargs),
        train_dataset=train_split,
    )
# Restricted to single-accelerator setups: per the original note, running this
# test with multiple GPUs makes it slower.
@require_torch_non_multi_accelerator
@run_test_using_subprocess
@slow
def test_can_resume_training_lm(self):
    """Check that an fp16 language-model run resumed from a checkpoint matches the uninterrupted run.

    A first trainer runs all `training_steps` steps and saves a checkpoint at every
    step. A second trainer resumes from an intermediate checkpoint. A random sample
    of the final parameters and the final trainer state must then agree between the
    two runs.
    """
    training_steps = 10
    resume_from_step = 8

    def _flatten_parameters(source):
        # One 1-D CPU tensor holding every model parameter, in `parameters()` order.
        return torch.cat([p.cpu().flatten() for p in source.model.parameters()])

    with tempfile.TemporaryDirectory() as tmpdir:
        enable_full_determinism(0)
        kwargs = {
            "output_dir": tmpdir,
            "fp16": True,
            "max_steps": training_steps,
            "per_device_train_batch_size": 1,
            "learning_rate": 1e-5,
            "lr_scheduler_type": "cosine",
            "save_strategy": "steps",
            "save_steps": 1,
            "logging_strategy": "steps",
            "logging_steps": 1,
            "report_to": "none",
        }

        # Reference run: train every step without interruption.
        trainer = get_language_model_trainer(**kwargs)
        trainer.train(resume_from_checkpoint=False)
        model_params = _flatten_parameters(trainer)
        # Only 1000 randomly chosen positions of the flattened parameter vector are compared.
        indices = torch.randint(0, len(model_params), (1000,))
        model_params_sample = model_params[indices].detach().clone()
        state1 = dataclasses.asdict(trainer.state)
        # Drop the references before the second trainer is built.
        del model_params, trainer
        # `check_saved_checkpoints` iterates `range(freq, total, freq)`, whose stop is
        # exclusive, so `total` must be `training_steps + 1` to include the last checkpoint.
        self.check_saved_checkpoints(
            tmpdir, freq=1, total=training_steps + 1, is_pretrained=True, safe_weights=True, use_scaler=True
        )

        # Resumed run: restart from an intermediate checkpoint with the same seed.
        enable_full_determinism(0)
        checkpoint = os.path.join(tmpdir, f"checkpoint-{resume_from_step+1}")
        trainer = get_language_model_trainer(**kwargs)
        trainer.train(resume_from_checkpoint=checkpoint)
        model_params = _flatten_parameters(trainer)

        # The sampled parameters and the trainer state must match the reference run.
        self.assertTrue(torch.allclose(model_params[indices], model_params_sample))
        state2 = dataclasses.asdict(trainer.state)
        self.check_trainer_state_are_the_same(state1, state2)
        del model_params, trainer