exclude fsdp from delay_optimizer_creation (#34140)
* exclude fsdp from delay_optimizer_creation * add test case for trainer: FSDP mode and fp8 as mixed precision * rearrange imports * ruff formatted * adapt _init_fsdp to fp8 * use _init_fsdp only when resume_from_checkpoint * In case of FDP, self.layer will be CheckpointWrapper which has no len() method * delete _init_fsdp * solve conflict * fix conflict * make fixup
This commit is contained in:
committed by
GitHub
parent
92bcdff2ef
commit
8b3b9b48fc
@@ -144,6 +144,7 @@ from .utils import (
|
|||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate.state import AcceleratorState, PartialState
|
from accelerate.state import AcceleratorState, PartialState
|
||||||
|
from accelerate.utils.imports import is_fp8_available
|
||||||
|
|
||||||
|
|
||||||
if is_pytest_available():
|
if is_pytest_available():
|
||||||
@@ -1000,6 +1001,13 @@ def require_torch_fp16(test_case):
|
|||||||
)(test_case)
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_fp8(test_case):
|
||||||
|
"""Decorator marking a test that requires supports for fp8"""
|
||||||
|
return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")(
|
||||||
|
test_case
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_bf16(test_case):
|
def require_torch_bf16(test_case):
|
||||||
"""Decorator marking a test that requires a device that supports bf16"""
|
"""Decorator marking a test that requires a device that supports bf16"""
|
||||||
return unittest.skipUnless(
|
return unittest.skipUnless(
|
||||||
|
|||||||
@@ -2209,7 +2209,7 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||||
|
|
||||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
|
||||||
|
|
||||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
||||||
if self._created_lr_scheduler:
|
if self._created_lr_scheduler:
|
||||||
@@ -2258,9 +2258,12 @@ class Trainer:
|
|||||||
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
||||||
use_accelerator_prepare = True if model is self.model else False
|
use_accelerator_prepare = True if model is self.model else False
|
||||||
|
|
||||||
if delay_optimizer_creation:
|
# configure fsdp plugin for qlora if any
|
||||||
if use_accelerator_prepare:
|
if use_accelerator_prepare:
|
||||||
self._fsdp_qlora_plugin_updates()
|
self._fsdp_qlora_plugin_updates()
|
||||||
|
|
||||||
|
if delay_optimizer_creation:
|
||||||
|
if use_accelerator_prepare:
|
||||||
self.model = self.accelerator.prepare(self.model)
|
self.model = self.accelerator.prepare(self.model)
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from transformers.testing_utils import (
|
|||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
get_torch_dist_unique_port,
|
get_torch_dist_unique_port,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
|
require_fp8,
|
||||||
|
require_fsdp,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -64,6 +66,7 @@ if is_torch_available():
|
|||||||
class TestFSDPTrainer(TestCasePlus):
|
class TestFSDPTrainer(TestCasePlus):
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
|
@require_fsdp
|
||||||
def test_trainer(self):
|
def test_trainer(self):
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
cmd = [
|
cmd = [
|
||||||
@@ -86,6 +89,35 @@ class TestFSDPTrainer(TestCasePlus):
|
|||||||
# successful return here == success - any errors would have caused an error in the sub-call
|
# successful return here == success - any errors would have caused an error in the sub-call
|
||||||
|
|
||||||
|
|
||||||
|
class TestFSDPTrainerFP8(TestCasePlus):
|
||||||
|
@require_accelerate
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
@require_fsdp
|
||||||
|
@require_fp8
|
||||||
|
def test_trainer(self):
|
||||||
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
cmd = [
|
||||||
|
"accelerate",
|
||||||
|
"launch",
|
||||||
|
"--use_fsdp",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
"--num_processes",
|
||||||
|
f"{torch.cuda.device_count()}",
|
||||||
|
"--mixed_precision",
|
||||||
|
"fp8",
|
||||||
|
"--fsdp_transformer_layer_cls_to_wrap",
|
||||||
|
"GPT2Block",
|
||||||
|
f"{self.test_file_dir}/test_trainer_fsdp.py",
|
||||||
|
"--output_dir",
|
||||||
|
f"{output_dir}",
|
||||||
|
"--report_to",
|
||||||
|
"none",
|
||||||
|
]
|
||||||
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
|
# successful return here == success - any errors would have caused an error in the sub-call
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = HfArgumentParser((Seq2SeqTrainingArguments,))
|
parser = HfArgumentParser((Seq2SeqTrainingArguments,))
|
||||||
training_args = parser.parse_args_into_dataclasses()[0]
|
training_args = parser.parse_args_into_dataclasses()[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user