From 4de1bdbf637fe6411c104c62ab385f660bfb1064 Mon Sep 17 00:00:00 2001 From: Shikhar Mishra <77426122+Itssshikhar@users.noreply.github.com> Date: Tue, 15 Oct 2024 17:18:10 +0530 Subject: [PATCH] Fix FSDP resume Initialization issue (#34032) * Fix FSDP Initialization for resume training * Added init_fsdp function to work with dummy values * Fix FSDP initialization for resuming training * Added CUDA decorator for tests * Added torch_gpu decorator to FSDP tests * Fixup for failing code quality tests --- src/transformers/trainer.py | 37 +++++++++++++++++++++++++++++++++++ tests/trainer/test_trainer.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 20b9f6dad2..5131676c95 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -273,6 +273,39 @@ def _get_fsdp_ckpt_kwargs(): return {} +def _init_fsdp(model, accelerator, device): + """ + Initialize Fully Sharded Data Parallel (FSDP) for the model. + + This function is needed to properly initialize FSDP when resuming from a checkpoint. + It runs a forward pass with dummy inputs to ensure FSDP is fully initialized. + See https://github.com/huggingface/transformers/issues/31892 for more details. + + Args: + model: The model to initialize with FSDP. + accelerator: The Accelerator object. + device: The device to run the model on. + + Returns: + The initialized FSDP model. + """ + model = accelerator.prepare(model) + model.train() + with torch.no_grad(): + # Run a forward pass with dummy inputs to initialize FSDP + dummy_input = { + name: torch.ones( + (1, 512), + dtype=torch.long, + device=device, + ) + for name in model.forward.__code__.co_varnames + if name != "self" + } + _ = model(**dummy_input) + return model + + if TYPE_CHECKING: import optuna @@ -601,6 +634,10 @@ class Trainer: " `Trainer`. 
@require_torch_gpu
@require_torch
@require_accelerate
class TestFSDPInitialization(unittest.TestCase):
    """Checks that building a `Trainer` with FSDP enabled leaves a usable, FSDP-wrapped model."""

    def test_fsdp_initialization(self):
        """Build a `Trainer` with `fsdp=True`, then verify the wrapping and run one forward pass."""
        config = RegressionModelConfig(a=1, b=1, double_output=False)
        model = RegressionPreTrainedModel(config)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # NOTE(review): `no_cuda=True` sits oddly next to `@require_torch_gpu`; kept as in the
            # original test - confirm which device this test is meant to exercise.
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                fsdp=True,
                fsdp_config={"min_num_params": 1},
                no_cuda=True,
            )
            trainer = Trainer(model=model, args=training_args)

            # Check for FSDP enabled
            self.assertTrue(trainer.is_fsdp_enabled)

            # Check if model is wrapped with FSDP
            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

            # `assertTrue(obj, FSDP)` would treat `FSDP` as the failure message and always pass
            # for a truthy model; `assertIsInstance` performs the intended type check.
            self.assertIsInstance(trainer.model, FSDP)

            # Running a forward pass to ensure FSDP is initialized. The input is created on the
            # device the trainer was configured with instead of being hard-wired to CPU.
            dummy_input = torch.ones((1, 1), dtype=torch.float, device=trainer.args.device)
            output = trainer.model(dummy_input)
            # Only require that the forward pass produced something; the truthiness of the
            # returned container is not a meaningful signal.
            self.assertIsNotNone(output)