Revert "Fix FSDP resume Initialization issue" (#34193)
Revert "Fix FSDP resume Initialization issue (#34032)"
This reverts commit 4de1bdbf63.
This commit is contained in:
@@ -273,39 +273,6 @@ def _get_fsdp_ckpt_kwargs():
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _init_fsdp(model, accelerator, device):
|
|
||||||
"""
|
|
||||||
Initialize Fully Sharded Data Parallel (FSDP) for the model.
|
|
||||||
|
|
||||||
This function is needed to properly initialize FSDP when resuming from a checkpoint.
|
|
||||||
It runs a forward pass with dummy inputs to ensure FSDP is fully initialized.
|
|
||||||
See https://github.com/huggingface/transformers/issues/31892 for more details.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model to initialize with FSDP.
|
|
||||||
accelerator: The Accelerator object.
|
|
||||||
device: The device to run the model on.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The initialized FSDP model.
|
|
||||||
"""
|
|
||||||
model = accelerator.prepare(model)
|
|
||||||
model.train()
|
|
||||||
with torch.no_grad():
|
|
||||||
# Run a forward pass with dummy inputs to initialize FSDP
|
|
||||||
dummy_input = {
|
|
||||||
name: torch.ones(
|
|
||||||
(1, 512),
|
|
||||||
dtype=torch.long,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
for name in model.forward.__code__.co_varnames
|
|
||||||
if name != "self"
|
|
||||||
}
|
|
||||||
_ = model(**dummy_input)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import optuna
|
import optuna
|
||||||
|
|
||||||
@@ -634,10 +601,6 @@ class Trainer:
|
|||||||
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
|
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
|
||||||
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
|
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
|
||||||
self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
|
|
||||||
|
|
||||||
if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
|
if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
|
||||||
self.optimizer is not None or self.lr_scheduler is not None
|
self.optimizer is not None or self.lr_scheduler is not None
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -4914,34 +4914,3 @@ class OptimizerAndModelInspectionTest(unittest.TestCase):
|
|||||||
param = next(model.parameters())
|
param = next(model.parameters())
|
||||||
group = trainer.get_optimizer_group(param)
|
group = trainer.get_optimizer_group(param)
|
||||||
self.assertIn(param, group["params"])
|
self.assertIn(param, group["params"])
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
|
||||||
@require_torch
|
|
||||||
@require_accelerate
|
|
||||||
class TestFSDPInitialization(unittest.TestCase):
|
|
||||||
def test_fsdp_initialization(self):
|
|
||||||
config = RegressionModelConfig(a=1, b=1, double_output=False)
|
|
||||||
model = RegressionPreTrainedModel(config)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
training_args = TrainingArguments(
|
|
||||||
output_dir=tmp_dir,
|
|
||||||
fsdp=True,
|
|
||||||
fsdp_config={"min_num_params": 1},
|
|
||||||
no_cuda=True,
|
|
||||||
)
|
|
||||||
trainer = Trainer(model=model, args=training_args)
|
|
||||||
|
|
||||||
# Check for FSDP enabled
|
|
||||||
self.assertTrue(trainer.is_fsdp_enabled)
|
|
||||||
|
|
||||||
# Check if model is wrapped with FSDP
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
|
|
||||||
self.assertTrue(trainer.model, FSDP)
|
|
||||||
|
|
||||||
# Running a forward pass to ensure FSDP is initialized
|
|
||||||
dummy_input = torch.ones((1, 1), dtype=torch.float)
|
|
||||||
output = trainer.model(dummy_input)
|
|
||||||
self.assertTrue(output)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user