Fix RNG reload in resume training from epoch checkpoint (#17055)
* Fix RNG reload in resume training from epoch checkpoint * Fix test
This commit is contained in:
@@ -789,12 +789,15 @@ class ModuleUtilsMixin:
|
|||||||
Returns:
|
Returns:
|
||||||
`int`: The total number of tokens.
|
`int`: The total number of tokens.
|
||||||
"""
|
"""
|
||||||
|
if not hasattr(self, "warnings_issued"):
|
||||||
|
self.warnings_issued = {}
|
||||||
if self.main_input_name in input_dict:
|
if self.main_input_name in input_dict:
|
||||||
return input_dict[self.main_input_name].numel()
|
return input_dict[self.main_input_name].numel()
|
||||||
else:
|
elif "estimate_tokens" not in self.warnings_issued:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
|
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
|
||||||
)
|
)
|
||||||
|
self.warnings_issued["estimate_tokens"] = True
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def floating_point_ops(
|
def floating_point_ops(
|
||||||
@@ -895,6 +898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Save config and origin of the pretrained weights if given in model
|
# Save config and origin of the pretrained weights if given in model
|
||||||
self.config = config
|
self.config = config
|
||||||
self.name_or_path = config.name_or_path
|
self.name_or_path = config.name_or_path
|
||||||
|
self.warnings_issued = {}
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1151,7 +1151,8 @@ class Trainer:
|
|||||||
kwargs:
|
kwargs:
|
||||||
Additional keyword arguments used to hide deprecated arguments
|
Additional keyword arguments used to hide deprecated arguments
|
||||||
"""
|
"""
|
||||||
resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
|
if resume_from_checkpoint is False:
|
||||||
|
resume_from_checkpoint = None
|
||||||
|
|
||||||
# memory metrics - must set up as early as possible
|
# memory metrics - must set up as early as possible
|
||||||
self._memory_tracker.start()
|
self._memory_tracker.start()
|
||||||
@@ -1395,6 +1396,9 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
|
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
|
||||||
|
|
||||||
|
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
|
||||||
|
self._load_rng_state(resume_from_checkpoint)
|
||||||
|
|
||||||
step = -1
|
step = -1
|
||||||
for step, inputs in enumerate(epoch_iterator):
|
for step, inputs in enumerate(epoch_iterator):
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,6 @@ from transformers.testing_utils import (
|
|||||||
require_torch_bf16,
|
require_torch_bf16,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_torch_non_multi_gpu,
|
|
||||||
require_torch_tf32,
|
require_torch_tf32,
|
||||||
require_torch_up_to_2_gpus,
|
require_torch_up_to_2_gpus,
|
||||||
require_wandb,
|
require_wandb,
|
||||||
@@ -162,11 +161,12 @@ class AlmostAccuracy:
|
|||||||
|
|
||||||
|
|
||||||
class RegressionModelConfig(PretrainedConfig):
|
class RegressionModelConfig(PretrainedConfig):
|
||||||
def __init__(self, a=0, b=0, double_output=False, **kwargs):
|
def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.a = a
|
self.a = a
|
||||||
self.b = b
|
self.b = b
|
||||||
self.double_output = double_output
|
self.double_output = double_output
|
||||||
|
self.random_torch = random_torch
|
||||||
self.hidden_size = 1
|
self.hidden_size = 1
|
||||||
|
|
||||||
|
|
||||||
@@ -264,14 +264,18 @@ if is_torch_available():
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.a = nn.Parameter(torch.tensor(config.a).float())
|
self.a = nn.Parameter(torch.tensor(config.a).float())
|
||||||
self.b = nn.Parameter(torch.tensor(config.b).float())
|
self.b = nn.Parameter(torch.tensor(config.b).float())
|
||||||
|
self.random_torch = config.random_torch
|
||||||
|
|
||||||
def forward(self, input_x, labels=None, **kwargs):
|
def forward(self, input_x, labels=None, **kwargs):
|
||||||
y = input_x * self.a + self.b
|
y = input_x * self.a + self.b
|
||||||
|
if self.random_torch:
|
||||||
torch_rand = torch.randn(1).squeeze()
|
torch_rand = torch.randn(1).squeeze()
|
||||||
np_rand = np.random.rand()
|
np_rand = np.random.rand()
|
||||||
rand_rand = random.random()
|
rand_rand = random.random()
|
||||||
|
|
||||||
y += 0.05 * torch_rand + 0.05 * torch.tensor(np_rand + rand_rand)
|
if self.random_torch:
|
||||||
|
y += 0.05 * torch_rand
|
||||||
|
y += 0.05 * torch.tensor(np_rand + rand_rand)
|
||||||
|
|
||||||
if labels is None:
|
if labels is None:
|
||||||
return (y,)
|
return (y,)
|
||||||
@@ -1016,17 +1020,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train(resume_from_checkpoint=True)
|
trainer.train(resume_from_checkpoint=True)
|
||||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||||
|
|
||||||
@require_torch_non_multi_gpu
|
|
||||||
def test_resume_training_with_randomness(self):
|
def test_resume_training_with_randomness(self):
|
||||||
# This test will fail flakily for more than 1 GPUs since the result will be slightly more different
|
# For more than 1 GPUs, since the randomness is introduced in the model and with DataParallel (which is used
|
||||||
# TODO: investigate why it fails for 2 GPUs?
|
# in this test for more than 2 GPUs), the calls to the torch RNG will happen in a random order (sometimes
|
||||||
|
# GPU 0 will call first and sometimes GPU 1).
|
||||||
|
random_torch = not torch.cuda.is_available() or torch.cuda.device_count() <= 1
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
train_dataset = RegressionDataset(length=128)
|
train_dataset = RegressionDataset(length=128)
|
||||||
eval_dataset = RegressionDataset()
|
eval_dataset = RegressionDataset()
|
||||||
|
|
||||||
config = RegressionModelConfig(a=0, b=2)
|
with self.subTest("Test every step"):
|
||||||
|
config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
|
||||||
model = RegressionRandomPreTrainedModel(config)
|
model = RegressionRandomPreTrainedModel(config)
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
@@ -1044,6 +1050,31 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertAlmostEqual(a, a1, delta=1e-8)
|
self.assertAlmostEqual(a, a1, delta=1e-8)
|
||||||
self.assertAlmostEqual(b, b1, delta=1e-8)
|
self.assertAlmostEqual(b, b1, delta=1e-8)
|
||||||
|
|
||||||
|
with self.subTest("Test every epoch"):
|
||||||
|
config = RegressionModelConfig(a=0, b=2, random_torch=random_torch)
|
||||||
|
model = RegressionRandomPreTrainedModel(config)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
args = RegressionTrainingArguments(tmp_dir, save_strategy="epoch", learning_rate=0.1)
|
||||||
|
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
|
||||||
|
model = RegressionRandomPreTrainedModel(config)
|
||||||
|
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||||
|
|
||||||
|
checkpoints = [d for d in os.listdir(tmp_dir) if d.startswith("checkpoint-")]
|
||||||
|
# There should be one checkpoint per epoch.
|
||||||
|
self.assertEqual(len(checkpoints), 3)
|
||||||
|
checkpoint_dir = sorted(checkpoints, key=lambda x: int(x.replace("checkpoint-", "")))[0]
|
||||||
|
|
||||||
|
trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, checkpoint_dir))
|
||||||
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
|
||||||
|
self.assertAlmostEqual(a, a1, delta=1e-8)
|
||||||
|
self.assertAlmostEqual(b, b1, delta=1e-8)
|
||||||
|
|
||||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||||
def test_training_with_resume_from_checkpoint_false(self):
|
def test_training_with_resume_from_checkpoint_false(self):
|
||||||
train_dataset = RegressionDataset(length=128)
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
|||||||
Reference in New Issue
Block a user