Adding option to save/reload scaler (#34932)
* Adding option to save/reload scaler
* Removing duplicate variable
* Adding save/reload test
* Small fixes on deterministic algorithm call
* Moving LLM test to another file to isolate its environment
* Moving back to old file and using subprocess to run the test in isolation
* Reverting accidental change
* Reverting accidental change
This commit is contained in:
@@ -305,9 +305,9 @@ logger = logging.get_logger(__name__)
|
|||||||
TRAINING_ARGS_NAME = "training_args.bin"
|
TRAINING_ARGS_NAME = "training_args.bin"
|
||||||
TRAINER_STATE_NAME = "trainer_state.json"
|
TRAINER_STATE_NAME = "trainer_state.json"
|
||||||
OPTIMIZER_NAME = "optimizer.pt"
|
OPTIMIZER_NAME = "optimizer.pt"
|
||||||
|
SCALER_NAME = "scaler.pt"
|
||||||
OPTIMIZER_NAME_BIN = "optimizer.bin"
|
OPTIMIZER_NAME_BIN = "optimizer.bin"
|
||||||
SCHEDULER_NAME = "scheduler.pt"
|
SCHEDULER_NAME = "scheduler.pt"
|
||||||
SCALER_NAME = "scaler.pt"
|
|
||||||
FSDP_MODEL_NAME = "pytorch_model_fsdp"
|
FSDP_MODEL_NAME = "pytorch_model_fsdp"
|
||||||
|
|
||||||
|
|
||||||
@@ -2394,6 +2394,7 @@ class Trainer:
|
|||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||||
|
self._load_scaler(resume_from_checkpoint)
|
||||||
|
|
||||||
# important: at this point:
|
# important: at this point:
|
||||||
# self.model is the Transformers Model
|
# self.model is the Transformers Model
|
||||||
@@ -3191,6 +3192,7 @@ class Trainer:
|
|||||||
if not self.args.save_only_model:
|
if not self.args.save_only_model:
|
||||||
# Save optimizer and scheduler
|
# Save optimizer and scheduler
|
||||||
self._save_optimizer_and_scheduler(output_dir)
|
self._save_optimizer_and_scheduler(output_dir)
|
||||||
|
self._save_scaler(output_dir)
|
||||||
# Save RNG state
|
# Save RNG state
|
||||||
self._save_rng_state(output_dir)
|
self._save_rng_state(output_dir)
|
||||||
|
|
||||||
@@ -3424,6 +3426,47 @@ class Trainer:
|
|||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
|
def _save_scaler(self, output_dir):
|
||||||
|
# See if there is a scaler attribute
|
||||||
|
try:
|
||||||
|
scaler = self.accelerator.scaler
|
||||||
|
except AttributeError:
|
||||||
|
return
|
||||||
|
if scaler is None:
|
||||||
|
return
|
||||||
|
if is_torch_xla_available():
|
||||||
|
xm.rendezvous("saving_scaler_state")
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
xm.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
|
# Save SCALER
|
||||||
|
if self.args.should_save and not is_torch_xla_available():
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
torch.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
|
def _load_scaler(self, checkpoint):
|
||||||
|
"""If scaler state exists, load it."""
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME))
|
||||||
|
|
||||||
|
if checkpoint_file_exists:
|
||||||
|
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||||
|
# Load in scaler states
|
||||||
|
if is_torch_xla_available():
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
scaler_state = torch.load(os.path.join(checkpoint, SCALER_NAME), map_location="cpu")
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
xm.send_cpu_data_to_device(scaler_state, self.args.device)
|
||||||
|
self.accelerator.scaler.load_state_dict(scaler_state)
|
||||||
|
else:
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
self.accelerator.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
def _load_callback_state(self):
|
def _load_callback_state(self):
|
||||||
"""If callback states exist and were passed in, restore their states if enabled"""
|
"""If callback states exist and were passed in, restore their states if enabled"""
|
||||||
if not self.args.restore_callback_states_from_checkpoint:
|
if not self.args.restore_callback_states_from_checkpoint:
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from transformers import (
|
|||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
|
enable_full_determinism,
|
||||||
get_polynomial_decay_schedule_with_warmup,
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -97,6 +98,7 @@ from transformers.testing_utils import (
|
|||||||
require_torchdynamo,
|
require_torchdynamo,
|
||||||
require_vision,
|
require_vision,
|
||||||
require_wandb,
|
require_wandb,
|
||||||
|
run_test_using_subprocess,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -576,13 +578,41 @@ if is_torch_available():
|
|||||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_language_model_trainer(**kwargs):
    """Build a `Trainer` that trains GPT-2 as a causal LM on a prompts dataset.

    Every keyword argument is forwarded unchanged to `TrainingArguments`.

    Returns:
        A `Trainer` holding the model, the training arguments and the tokenized
        "train" split.
    """
    import datasets

    raw_data = datasets.load_dataset("fka/awesome-chatgpt-prompts")
    lm = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
    # Padding reuses the end-of-sequence token.
    tok.pad_token = tok.eos_token

    def _encode_batch(batch):
        # Labels are a copy of the input ids, stored as int64.
        encoded = tok(batch["prompt"], padding="max_length", truncation=True)
        input_ids = np.array(encoded["input_ids"])
        encoded["labels"] = input_ids.astype(np.int64)
        return encoded

    encoded_data = raw_data.map(_encode_batch, batched=True)
    args = TrainingArguments(**kwargs)

    return Trainer(
        model=lm,
        args=args,
        train_dataset=encoded_data["train"],
    )
|
||||||
|
|
||||||
|
|
||||||
class TrainerIntegrationCommon:
|
class TrainerIntegrationCommon:
|
||||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
|
def check_saved_checkpoints(
|
||||||
|
self, output_dir, freq, total, is_pretrained=True, safe_weights=True, use_scaler=False
|
||||||
|
):
|
||||||
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
|
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
|
||||||
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||||
if is_pretrained:
|
if is_pretrained:
|
||||||
file_list.append("config.json")
|
file_list.append("config.json")
|
||||||
|
if use_scaler:
|
||||||
|
file_list.append("scaler.pt")
|
||||||
for step in range(freq, total, freq):
|
for step in range(freq, total, freq):
|
||||||
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||||
self.assertTrue(os.path.isdir(checkpoint))
|
self.assertTrue(os.path.isdir(checkpoint))
|
||||||
@@ -3095,6 +3125,62 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train(resume_from_checkpoint=True)
|
trainer.train(resume_from_checkpoint=True)
|
||||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||||
|
|
||||||
|
# `require_torch_non_multi_accelerator` is necessary because this worker blocks runs when multiple GPUs are
# used, which makes the test slower.
@require_torch_non_multi_accelerator
@run_test_using_subprocess
@slow
def test_can_resume_training_lm(self):
    """Train a small language model, resume from a late checkpoint, and compare the two runs.

    The resumed run must reproduce a sample of the final parameters and the
    trainer state of the uninterrupted run.
    """
    total_steps = 10
    resume_from_step = 8
    with tempfile.TemporaryDirectory() as tmpdir:
        enable_full_determinism(0)
        trainer_kwargs = {
            "output_dir": tmpdir,
            "fp16": True,
            "max_steps": total_steps,
            "per_device_train_batch_size": 1,
            "learning_rate": 1e-5,
            "lr_scheduler_type": "cosine",
            "save_strategy": "steps",
            "save_steps": 1,
            "logging_strategy": "steps",
            "logging_steps": 1,
            "report_to": "none",
        }

        # Reference run: train from scratch through all steps.
        trainer = get_language_model_trainer(**trainer_kwargs)
        trainer.train(resume_from_checkpoint=False)

        # Flatten every parameter into a single vector and remember 1000 randomly chosen entries.
        flat_params = torch.cat([p.cpu().flatten() for p in trainer.model.parameters()])
        sample_idx = torch.randint(0, len(flat_params), (1000,))
        expected_sample = flat_params[sample_idx].detach().clone()
        state_before = dataclasses.asdict(trainer.state)
        # Drop the references before the second trainer is built.
        del flat_params, trainer

        # Every step must have produced a checkpoint; `total` is passed as steps + 1 so the last step is checked.
        self.check_saved_checkpoints(
            tmpdir,
            freq=1,
            total=total_steps + 1,
            is_pretrained=True,
            safe_weights=True,
            use_scaler=True,
        )

        # Resumed run: same seed, restart from an intermediate checkpoint.
        enable_full_determinism(0)
        checkpoint = os.path.join(tmpdir, f"checkpoint-{resume_from_step+1}")
        trainer = get_language_model_trainer(**trainer_kwargs)
        trainer.train(resume_from_checkpoint=checkpoint)
        flat_params = torch.cat([p.cpu().flatten() for p in trainer.model.parameters()])

        # The sampled parameters and the trainer state must match the reference run.
        self.assertTrue(torch.allclose(flat_params[sample_idx], expected_sample))
        state_after = dataclasses.asdict(trainer.state)
        self.check_trainer_state_are_the_same(state_before, state_after)
        del flat_params, trainer
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
reason="@muellerzr: Fix once Trainer can take an accelerate configuration. Need to set `seedable_sampler=True`."
|
reason="@muellerzr: Fix once Trainer can take an accelerate configuration. Need to set `seedable_sampler=True`."
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user