From 3f06f95ebe617b192251ef756518690f5bc7ff76 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 16 Oct 2024 21:25:18 +0200 Subject: [PATCH] Revert "Fix FSDP resume Initialization issue" (#34193) Revert "Fix FSDP resume Initialization issue (#34032)" This reverts commit 4de1bdbf637fe6411c104c62ab385f660bfb1064. --- src/transformers/trainer.py | 37 ----------------------------------- tests/trainer/test_trainer.py | 31 ----------------------------- 2 files changed, 68 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5131676c95..20b9f6dad2 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -273,39 +273,6 @@ def _get_fsdp_ckpt_kwargs(): return {} -def _init_fsdp(model, accelerator, device): - """ - Initialize Fully Sharded Data Parallel (FSDP) for the model. - - This function is needed to properly initialize FSDP when resuming from a checkpoint. - It runs a forward pass with dummy inputs to ensure FSDP is fully initialized. - See https://github.com/huggingface/transformers/issues/31892 for more details. - - Args: - model: The model to initialize with FSDP. - accelerator: The Accelerator object. - device: The device to run the model on. - - Returns: - The initialized FSDP model. - """ - model = accelerator.prepare(model) - model.train() - with torch.no_grad(): - # Run a forward pass with dummy inputs to initialize FSDP - dummy_input = { - name: torch.ones( - (1, 512), - dtype=torch.long, - device=device, - ) - for name in model.forward.__code__.co_varnames - if name != "self" - } - _ = model(**dummy_input) - return model - - if TYPE_CHECKING: import optuna @@ -634,10 +601,6 @@ class Trainer: " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." ) - - if self.is_fsdp_enabled: - self.model = _init_fsdp(self.model, self.accelerator, self.args.device) - if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and ( self.optimizer is not None or self.lr_scheduler is not None ): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8feb5d92e8..cbc93faf50 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -4914,34 +4914,3 @@ class OptimizerAndModelInspectionTest(unittest.TestCase): param = next(model.parameters()) group = trainer.get_optimizer_group(param) self.assertIn(param, group["params"]) - - -@require_torch_gpu -@require_torch -@require_accelerate -class TestFSDPInitialization(unittest.TestCase): - def test_fsdp_initialization(self): - config = RegressionModelConfig(a=1, b=1, double_output=False) - model = RegressionPreTrainedModel(config) - - with tempfile.TemporaryDirectory() as tmp_dir: - training_args = TrainingArguments( - output_dir=tmp_dir, - fsdp=True, - fsdp_config={"min_num_params": 1}, - no_cuda=True, - ) - trainer = Trainer(model=model, args=training_args) - - # Check for FSDP enabled - self.assertTrue(trainer.is_fsdp_enabled) - - # Check if model is wrapped with FSDP - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - self.assertTrue(trainer.model, FSDP) - - # Running a forward pass to ensure FSDP is initialized - dummy_input = torch.ones((1, 1), dtype=torch.float) - output = trainer.model(dummy_input) - self.assertTrue(output)