adding option to save/reload scaler (#34932)

* Adding option to save/reload scaler

* Removing duplicate variable

* Adding save/reload test

* Small fixes on deterministic algorithm call

* Moving LLM test to another file to isolate its environment

* Moving back to old file and using subprocess to run test isolated

* Reverting back accidental change

* Reverting back accidental change
This commit is contained in:
hsilva664
2025-02-12 11:48:16 -03:00
committed by GitHub
parent a33ac830af
commit 281c0c8b5b
2 changed files with 131 additions and 2 deletions

View File

@@ -305,9 +305,9 @@ logger = logging.get_logger(__name__)
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCALER_NAME = "scaler.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
@@ -2394,6 +2394,7 @@ class Trainer:
# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
self._load_scaler(resume_from_checkpoint)
# important: at this point:
# self.model is the Transformers Model
@@ -3191,6 +3192,7 @@ class Trainer:
if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)
self._save_scaler(output_dir)
# Save RNG state
self._save_rng_state(output_dir)
@@ -3424,6 +3426,47 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
def _save_scaler(self, output_dir):
# See if there is a scaler attribute
try:
scaler = self.accelerator.scaler
except AttributeError:
return
if scaler is None:
return
if is_torch_xla_available():
xm.rendezvous("saving_scaler_state")
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
reissue_pt_warnings(caught_warnings)
# Save SCALER
if self.args.should_save and not is_torch_xla_available():
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
reissue_pt_warnings(caught_warnings)
def _load_scaler(self, checkpoint):
"""If scaler state exists, load it."""
if checkpoint is None:
return
checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME))
if checkpoint_file_exists:
# On TPU we have to take some extra precautions to properly load the states on the right device.
# Load in scaler states
if is_torch_xla_available():
with warnings.catch_warnings(record=True) as caught_warnings:
scaler_state = torch.load(os.path.join(checkpoint, SCALER_NAME), map_location="cpu")
reissue_pt_warnings(caught_warnings)
xm.send_cpu_data_to_device(scaler_state, self.args.device)
self.accelerator.scaler.load_state_dict(scaler_state)
else:
with warnings.catch_warnings(record=True) as caught_warnings:
self.accelerator.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
reissue_pt_warnings(caught_warnings)
def _load_callback_state(self):
"""If callback states exist and were passed in, restore their states if enabled"""
if not self.args.restore_callback_states_from_checkpoint:

View File

@@ -46,6 +46,7 @@ from transformers import (
PretrainedConfig,
TrainerCallback,
TrainingArguments,
enable_full_determinism,
get_polynomial_decay_schedule_with_warmup,
is_torch_available,
logging,
@@ -97,6 +98,7 @@ from transformers.testing_utils import (
require_torchdynamo,
require_vision,
require_wandb,
run_test_using_subprocess,
slow,
torch_device,
)
@@ -576,13 +578,41 @@ if is_torch_available():
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
def get_language_model_trainer(**kwargs):
import datasets
dataset = datasets.load_dataset("fka/awesome-chatgpt-prompts")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token
def _tokenize_function(examples):
model_inputs = tokenizer(examples["prompt"], padding="max_length", truncation=True)
model_inputs["labels"] = np.array(model_inputs["input_ids"]).astype(np.int64)
return model_inputs
tokenized_datasets = dataset.map(_tokenize_function, batched=True)
training_args = TrainingArguments(**kwargs)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
)
return trainer
class TrainerIntegrationCommon:
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
def check_saved_checkpoints(
self, output_dir, freq, total, is_pretrained=True, safe_weights=True, use_scaler=False
):
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
if is_pretrained:
file_list.append("config.json")
if use_scaler:
file_list.append("scaler.pt")
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint))
@@ -3095,6 +3125,62 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
# require_torch_non_multi_accelerator is necessary because this worker blocks runs when using multiple GPUs, making
# the test slower.
@require_torch_non_multi_accelerator
@run_test_using_subprocess
@slow
def test_can_resume_training_lm(self):
# Check if it works for a simple language modeling example
training_steps = 10
resume_from_step = 8
with tempfile.TemporaryDirectory() as tmpdir:
enable_full_determinism(0)
kwargs = {
"output_dir": tmpdir,
"fp16": True,
"max_steps": training_steps,
"per_device_train_batch_size": 1,
"learning_rate": 1e-5,
"lr_scheduler_type": "cosine",
"save_strategy": "steps",
"save_steps": 1,
"logging_strategy": "steps",
"logging_steps": 1,
"report_to": "none",
}
trainer = get_language_model_trainer(**kwargs)
trainer.train(resume_from_checkpoint=False)
# Get the parameter length of the model
model_params = torch.cat([p.cpu().flatten() for p in trainer.model.parameters()])
model_param_len = len(model_params)
# Sample uniform indexes and save the values of the parameters (considering an unrolled vector with
# all of them)
indices = torch.randint(0, model_param_len, (1000,))
# Save the values of the parameters for later comparison
model_params_sample = model_params[indices].detach().clone()
state1 = dataclasses.asdict(trainer.state)
# Delete the reference
del model_params, trainer
# Checks if all checkpoints are there, +1 is necessary because range is 1-indexed
self.check_saved_checkpoints(
tmpdir, freq=1, total=training_steps + 1, is_pretrained=True, safe_weights=True, use_scaler=True
)
# Checkpoint at intermediate step
enable_full_determinism(0)
checkpoint = os.path.join(tmpdir, f"checkpoint-{resume_from_step+1}")
trainer = get_language_model_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint)
model_params = torch.cat([p.cpu().flatten() for p in trainer.model.parameters()])
# Check that the parameters are the same
self.assertTrue(torch.allclose(model_params[indices], model_params_sample))
state2 = dataclasses.asdict(trainer.state)
self.check_trainer_state_are_the_same(state1, state2)
del model_params, trainer
@unittest.skip(
reason="@muellerzr: Fix once Trainer can take an accelerate configuration. Need to set `seedable_sampler=True`."
)