Reproducible checkpoint (#11582)
* Set generator in dataloader * Use generator in all random samplers * Checkpoint all RNG states * Final version * Quality * Test * Address review comments * Quality * Remove debug util * Add python and numpy RNGs * Split states in different files in distributed * Quality * local_rank for TPUs * Only use generator when accepted * Add test * Set seed to avoid flakiness * Make test less flaky * Quality
This commit is contained in:
@@ -204,7 +204,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
run_ner.main()
|
run_ner.main()
|
||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||||
self.assertGreaterEqual(result["eval_precision"], 0.75)
|
|
||||||
self.assertLess(result["eval_loss"], 0.5)
|
self.assertLess(result["eval_loss"], 0.5)
|
||||||
|
|
||||||
def test_run_squad(self):
|
def test_run_squad(self):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import collections
|
|||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
@@ -127,6 +128,7 @@ from .utils import logging
|
|||||||
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
||||||
|
|
||||||
|
|
||||||
|
_is_torch_generator_available = False
|
||||||
_is_native_amp_available = False
|
_is_native_amp_available = False
|
||||||
|
|
||||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||||
@@ -141,6 +143,7 @@ if is_apex_available():
|
|||||||
from apex import amp
|
from apex import amp
|
||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||||
|
_is_torch_generator_available = True
|
||||||
_is_native_amp_available = True
|
_is_native_amp_available = True
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
@@ -525,6 +528,11 @@ class Trainer:
|
|||||||
if not isinstance(self.train_dataset, collections.abc.Sized):
|
if not isinstance(self.train_dataset, collections.abc.Sized):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
generator = None
|
||||||
|
if self.args.world_size <= 1 and _is_torch_generator_available:
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
|
||||||
|
|
||||||
# Build the sampler.
|
# Build the sampler.
|
||||||
if self.args.group_by_length:
|
if self.args.group_by_length:
|
||||||
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
|
||||||
@@ -538,7 +546,11 @@ class Trainer:
|
|||||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return LengthGroupedSampler(
|
return LengthGroupedSampler(
|
||||||
self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
|
self.train_dataset,
|
||||||
|
self.args.train_batch_size,
|
||||||
|
lengths=lengths,
|
||||||
|
model_input_name=model_input_name,
|
||||||
|
generator=generator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return DistributedLengthGroupedSampler(
|
return DistributedLengthGroupedSampler(
|
||||||
@@ -553,6 +565,8 @@ class Trainer:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
|
if _is_torch_generator_available:
|
||||||
|
return RandomSampler(self.train_dataset, generator=generator)
|
||||||
return RandomSampler(self.train_dataset)
|
return RandomSampler(self.train_dataset)
|
||||||
elif (
|
elif (
|
||||||
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
|
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
|
||||||
@@ -1224,6 +1238,8 @@ class Trainer:
|
|||||||
steps_trained_in_current_epoch -= 1
|
steps_trained_in_current_epoch -= 1
|
||||||
if steps_trained_progress_bar is not None:
|
if steps_trained_progress_bar is not None:
|
||||||
steps_trained_progress_bar.update(1)
|
steps_trained_progress_bar.update(1)
|
||||||
|
if steps_trained_in_current_epoch == 0:
|
||||||
|
self._load_rng_state(resume_from_checkpoint)
|
||||||
continue
|
continue
|
||||||
elif steps_trained_progress_bar is not None:
|
elif steps_trained_progress_bar is not None:
|
||||||
steps_trained_progress_bar.close()
|
steps_trained_progress_bar.close()
|
||||||
@@ -1381,6 +1397,41 @@ class Trainer:
|
|||||||
self._save_checkpoint(model, trial, metrics=metrics)
|
self._save_checkpoint(model, trial, metrics=metrics)
|
||||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||||
|
|
||||||
|
def _load_rng_state(self, checkpoint):
|
||||||
|
# Load RNG states from `checkpoint`
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
|
||||||
|
if local_rank != -1:
|
||||||
|
rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
|
||||||
|
if not os.path.isfile(os.path.join(checkpoint, rng_file)):
|
||||||
|
logger.info(
|
||||||
|
f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
|
||||||
|
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
rng_file = os.path.join(checkpoint, "rng_state.pth")
|
||||||
|
if not os.path.isfile(os.path.join(checkpoint, rng_file)):
|
||||||
|
logger.info(
|
||||||
|
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
|
||||||
|
"fashion, reproducibility is not guaranteed."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
checkpoint_rng_state = torch.load(rng_file)
|
||||||
|
random.setstate(checkpoint_rng_state["python"])
|
||||||
|
np.random.set_state(checkpoint_rng_state["numpy"])
|
||||||
|
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
if self.args.local_rank != -1:
|
||||||
|
torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
|
||||||
|
else:
|
||||||
|
torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
xm.set_rng_state(checkpoint_rng_state["xla"])
|
||||||
|
|
||||||
def _save_checkpoint(self, model, trial, metrics=None):
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||||
# want to save except FullyShardedDDP.
|
# want to save except FullyShardedDDP.
|
||||||
@@ -1460,6 +1511,28 @@ class Trainer:
|
|||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||||
|
|
||||||
|
# Save RNG state in non-distributed training
|
||||||
|
rng_states = {
|
||||||
|
"python": random.getstate(),
|
||||||
|
"numpy": np.random.get_state(),
|
||||||
|
"cpu": torch.random.get_rng_state(),
|
||||||
|
}
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
if self.args.local_rank == -1:
|
||||||
|
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
|
||||||
|
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
|
||||||
|
else:
|
||||||
|
rng_states["cuda"] = torch.cuda.random.get_rng_state()
|
||||||
|
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
rng_states["xla"] = xm.get_rng_state()
|
||||||
|
|
||||||
|
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
|
||||||
|
if local_rank == -1:
|
||||||
|
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
|
||||||
|
else:
|
||||||
|
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
||||||
|
|
||||||
def _load_optimizer_and_scheduler(self, checkpoint):
|
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||||
"""If optimizer and scheduler states exist, load them."""
|
"""If optimizer and scheduler states exist, load them."""
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
|
|||||||
@@ -510,6 +510,7 @@ class LengthGroupedSampler(Sampler):
|
|||||||
batch_size: int,
|
batch_size: int,
|
||||||
lengths: Optional[List[int]] = None,
|
lengths: Optional[List[int]] = None,
|
||||||
model_input_name: Optional[str] = None,
|
model_input_name: Optional[str] = None,
|
||||||
|
generator=None,
|
||||||
):
|
):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -525,12 +526,13 @@ class LengthGroupedSampler(Sampler):
|
|||||||
)
|
)
|
||||||
lengths = [len(feature[self.model_input_name]) for feature in dataset]
|
lengths = [len(feature[self.model_input_name]) for feature in dataset]
|
||||||
self.lengths = lengths
|
self.lengths = lengths
|
||||||
|
self.generator = generator
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.lengths)
|
return len(self.lengths)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
indices = get_length_grouped_indices(self.lengths, self.batch_size)
|
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
|
||||||
return iter(indices)
|
return iter(indices)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,9 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import gc
|
import gc
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -195,6 +197,28 @@ if is_torch_available():
|
|||||||
loss = torch.nn.functional.mse_loss(y, labels)
|
loss = torch.nn.functional.mse_loss(y, labels)
|
||||||
return (loss, y, y) if self.double_output else (loss, y)
|
return (loss, y, y) if self.double_output else (loss, y)
|
||||||
|
|
||||||
|
class RegressionRandomPreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = RegressionModelConfig
|
||||||
|
base_model_prefix = "regression"
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
|
||||||
|
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
|
||||||
|
|
||||||
|
def forward(self, input_x, labels=None, **kwargs):
|
||||||
|
y = input_x * self.a + self.b
|
||||||
|
torch_rand = torch.randn(1).squeeze()
|
||||||
|
np_rand = np.random.rand()
|
||||||
|
rand_rand = random.random()
|
||||||
|
|
||||||
|
y += 0.05 * torch_rand + 0.05 * torch.tensor(np_rand + rand_rand)
|
||||||
|
|
||||||
|
if labels is None:
|
||||||
|
return (y,)
|
||||||
|
loss = torch.nn.functional.mse_loss(y, labels)
|
||||||
|
return (loss, y)
|
||||||
|
|
||||||
class TstLayer(torch.nn.Module):
|
class TstLayer(torch.nn.Module):
|
||||||
def __init__(self, hidden_size):
|
def __init__(self, hidden_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -699,6 +723,34 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train(resume_from_checkpoint=True)
|
trainer.train(resume_from_checkpoint=True)
|
||||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||||
|
|
||||||
|
def test_resume_training_with_randomness(self):
|
||||||
|
if torch.cuda.device_count() >= 2:
|
||||||
|
# This test will fail flakily for more than 2 GPUs since the result will be slightly more different.
|
||||||
|
return
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
eval_dataset = RegressionDataset()
|
||||||
|
|
||||||
|
config = RegressionModelConfig(a=0, b=2)
|
||||||
|
model = RegressionRandomPreTrainedModel(config)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
|
||||||
|
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
|
||||||
|
model = RegressionRandomPreTrainedModel(config)
|
||||||
|
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||||
|
trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
|
||||||
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
|
||||||
|
self.assertTrue(math.isclose(a, a1, rel_tol=1e-8))
|
||||||
|
self.assertTrue(math.isclose(b, b1, rel_tol=1e-8))
|
||||||
|
|
||||||
def test_resume_training_with_gradient_accumulation(self):
|
def test_resume_training_with_gradient_accumulation(self):
|
||||||
if torch.cuda.device_count() > 2:
|
if torch.cuda.device_count() > 2:
|
||||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||||
|
|||||||
Reference in New Issue
Block a user