adding option to save/reload scaler (#34932)
* Adding option to save/reload scaler * Removing duplicate variable * Adding save/reload test * Small fixes on deterministic algorithm call * Moving LLM test to another file to isolate its environment * Moving back to old file and using subprocess to run test isolated * Reverting back accidental change * Reverting back accidental change
This commit is contained in:
@@ -305,9 +305,9 @@ logger = logging.get_logger(__name__)
|
||||
TRAINING_ARGS_NAME = "training_args.bin"
|
||||
TRAINER_STATE_NAME = "trainer_state.json"
|
||||
OPTIMIZER_NAME = "optimizer.pt"
|
||||
SCALER_NAME = "scaler.pt"
|
||||
OPTIMIZER_NAME_BIN = "optimizer.bin"
|
||||
SCHEDULER_NAME = "scheduler.pt"
|
||||
SCALER_NAME = "scaler.pt"
|
||||
FSDP_MODEL_NAME = "pytorch_model_fsdp"
|
||||
|
||||
|
||||
@@ -2394,6 +2394,7 @@ class Trainer:
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||
self._load_scaler(resume_from_checkpoint)
|
||||
|
||||
# important: at this point:
|
||||
# self.model is the Transformers Model
|
||||
@@ -3191,6 +3192,7 @@ class Trainer:
|
||||
if not self.args.save_only_model:
|
||||
# Save optimizer and scheduler
|
||||
self._save_optimizer_and_scheduler(output_dir)
|
||||
self._save_scaler(output_dir)
|
||||
# Save RNG state
|
||||
self._save_rng_state(output_dir)
|
||||
|
||||
@@ -3424,6 +3426,47 @@ class Trainer:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def _save_scaler(self, output_dir):
|
||||
# See if there is a scaler attribute
|
||||
try:
|
||||
scaler = self.accelerator.scaler
|
||||
except AttributeError:
|
||||
return
|
||||
if scaler is None:
|
||||
return
|
||||
if is_torch_xla_available():
|
||||
xm.rendezvous("saving_scaler_state")
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
xm.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
# Save SCALER
|
||||
if self.args.should_save and not is_torch_xla_available():
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def _load_scaler(self, checkpoint):
|
||||
"""If scaler state exists, load it."""
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME))
|
||||
|
||||
if checkpoint_file_exists:
|
||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||
# Load in scaler states
|
||||
if is_torch_xla_available():
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
scaler_state = torch.load(os.path.join(checkpoint, SCALER_NAME), map_location="cpu")
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
xm.send_cpu_data_to_device(scaler_state, self.args.device)
|
||||
self.accelerator.scaler.load_state_dict(scaler_state)
|
||||
else:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.accelerator.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def _load_callback_state(self):
|
||||
"""If callback states exist and were passed in, restore their states if enabled"""
|
||||
if not self.args.restore_callback_states_from_checkpoint:
|
||||
|
||||
@@ -46,6 +46,7 @@ from transformers import (
|
||||
PretrainedConfig,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
enable_full_determinism,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
is_torch_available,
|
||||
logging,
|
||||
@@ -97,6 +98,7 @@ from transformers.testing_utils import (
|
||||
require_torchdynamo,
|
||||
require_vision,
|
||||
require_wandb,
|
||||
run_test_using_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -576,13 +578,41 @@ if is_torch_available():
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
def get_language_model_trainer(**kwargs):
|
||||
import datasets
|
||||
|
||||
dataset = datasets.load_dataset("fka/awesome-chatgpt-prompts")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def _tokenize_function(examples):
|
||||
model_inputs = tokenizer(examples["prompt"], padding="max_length", truncation=True)
|
||||
model_inputs["labels"] = np.array(model_inputs["input_ids"]).astype(np.int64)
|
||||
return model_inputs
|
||||
|
||||
tokenized_datasets = dataset.map(_tokenize_function, batched=True)
|
||||
training_args = TrainingArguments(**kwargs)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
class TrainerIntegrationCommon:
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
|
||||
def check_saved_checkpoints(
|
||||
self, output_dir, freq, total, is_pretrained=True, safe_weights=True, use_scaler=False
|
||||
):
|
||||
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
|
||||
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||
if is_pretrained:
|
||||
file_list.append("config.json")
|
||||
if use_scaler:
|
||||
file_list.append("scaler.pt")
|
||||
for step in range(freq, total, freq):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||
self.assertTrue(os.path.isdir(checkpoint))
|
||||
@@ -3095,6 +3125,62 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||
|
||||
# require_torch_non_multi_accelerator is necessary because this worker blocks runs when using multiple GPUs, making
|
||||
# the test slower.
|
||||
@require_torch_non_multi_accelerator
|
||||
@run_test_using_subprocess
|
||||
@slow
|
||||
def test_can_resume_training_lm(self):
|
||||
# Check if it works for a simple language modeling example
|
||||
training_steps = 10
|
||||
resume_from_step = 8
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
enable_full_determinism(0)
|
||||
kwargs = {
|
||||
"output_dir": tmpdir,
|
||||
"fp16": True,
|
||||
"max_steps": training_steps,
|
||||
"per_device_train_batch_size": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"lr_scheduler_type": "cosine",
|
||||
"save_strategy": "steps",
|
||||
"save_steps": 1,
|
||||
"logging_strategy": "steps",
|
||||
"logging_steps": 1,
|
||||
"report_to": "none",
|
||||
}
|
||||
|
||||
trainer = get_language_model_trainer(**kwargs)
|
||||
trainer.train(resume_from_checkpoint=False)
|
||||
# Get the parameter length of the model
|
||||
model_params = torch.cat([p.cpu().flatten() for p in trainer.model.parameters()])
|
||||
model_param_len = len(model_params)
|
||||
# Sample uniform indexes and save the values of the parameters (considering an unrolled vector with
|
||||
# all of them)
|
||||
indices = torch.randint(0, model_param_len, (1000,))
|
||||
# Save the values of the parameters for later comparison
|
||||
model_params_sample = model_params[indices].detach().clone()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
# Delete the reference
|
||||
del model_params, trainer
|
||||
# Checks if all checkpoints are there, +1 is necessary because range is 1-indexed
|
||||
self.check_saved_checkpoints(
|
||||
tmpdir, freq=1, total=training_steps + 1, is_pretrained=True, safe_weights=True, use_scaler=True
|
||||
)
|
||||
|
||||
# Checkpoint at intermediate step
|
||||
enable_full_determinism(0)
|
||||
checkpoint = os.path.join(tmpdir, f"checkpoint-{resume_from_step+1}")
|
||||
trainer = get_language_model_trainer(**kwargs)
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
model_params = torch.cat([p.cpu().flatten() for p in trainer.model.parameters()])
|
||||
|
||||
# Check that the parameters are the same
|
||||
self.assertTrue(torch.allclose(model_params[indices], model_params_sample))
|
||||
state2 = dataclasses.asdict(trainer.state)
|
||||
self.check_trainer_state_are_the_same(state1, state2)
|
||||
del model_params, trainer
|
||||
|
||||
@unittest.skip(
|
||||
reason="@muellerzr: Fix once Trainer can take an accelerate configuration. Need to set `seedable_sampler=True`."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user