From 8b3b9b48fcd6bc06bd9c576f1b09266d577db257 Mon Sep 17 00:00:00 2001 From: AbdelKarim ELJANDOUBI <78537694+eljandoubi@users.noreply.github.com> Date: Mon, 28 Oct 2024 13:50:16 +0100 Subject: [PATCH] exclude fsdp from delay_optimizer_creation (#34140) * exclude fsdp from delay_optimizer_creation * add test case for trainer: FSDP mode and fp8 as mixed precision * rearrange imports * ruff formatted * adapt _init_fsdp to fp8 * use _init_fsdp only when resume_from_checkpoint * In case of FDP, self.layer will be CheckpointWrapper which has no len() method * delete _init_fsdp * solve conflict * fix conflict * make fixup --- src/transformers/testing_utils.py | 8 ++++++++ src/transformers/trainer.py | 7 +++++-- tests/trainer/test_trainer_fsdp.py | 32 ++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 7bb2d5049d..2781e9e102 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -144,6 +144,7 @@ from .utils import ( if is_accelerate_available(): from accelerate.state import AcceleratorState, PartialState + from accelerate.utils.imports import is_fp8_available if is_pytest_available(): @@ -1000,6 +1001,13 @@ def require_torch_fp16(test_case): )(test_case) +def require_fp8(test_case): + """Decorator marking a test that requires supports for fp8""" + return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")( + test_case + ) + + def require_torch_bf16(test_case): """Decorator marking a test that requires a device that supports bf16""" return unittest.skipUnless( diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8fe25b7466..64cb5c6bd4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2209,7 +2209,7 @@ class Trainer: else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: @@ -2258,9 +2258,12 @@ class Trainer: # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False + # configure fsdp plugin for qlora if any + if use_accelerator_prepare: + self._fsdp_qlora_plugin_updates() + if delay_optimizer_creation: if use_accelerator_prepare: - self._fsdp_qlora_plugin_updates() self.model = self.accelerator.prepare(self.model) self.create_optimizer_and_scheduler(num_training_steps=max_steps) diff --git a/tests/trainer/test_trainer_fsdp.py b/tests/trainer/test_trainer_fsdp.py index 994a82a8db..4bcf5de045 100644 --- a/tests/trainer/test_trainer_fsdp.py +++ b/tests/trainer/test_trainer_fsdp.py @@ -20,6 +20,8 @@ from transformers.testing_utils import ( execute_subprocess_async, get_torch_dist_unique_port, require_accelerate, + require_fp8, + require_fsdp, require_torch_multi_gpu, ) @@ -64,6 +66,7 @@ if is_torch_available(): class TestFSDPTrainer(TestCasePlus): @require_accelerate @require_torch_multi_gpu + @require_fsdp def test_trainer(self): output_dir = self.get_auto_remove_tmp_dir() cmd = [ @@ -86,6 +89,35 @@ class TestFSDPTrainer(TestCasePlus): # successful return here == success - any errors would have caused an error in the sub-call +class TestFSDPTrainerFP8(TestCasePlus): + @require_accelerate + @require_torch_multi_gpu + @require_fsdp + @require_fp8 + def test_trainer(self): + output_dir = self.get_auto_remove_tmp_dir() + cmd = [ + "accelerate", + "launch", + "--use_fsdp", + "--main_process_port", + f"{get_torch_dist_unique_port()}", + "--num_processes", + f"{torch.cuda.device_count()}", + "--mixed_precision", + "fp8", + "--fsdp_transformer_layer_cls_to_wrap", + "GPT2Block", + f"{self.test_file_dir}/test_trainer_fsdp.py", + "--output_dir", + f"{output_dir}", + "--report_to", + "none", + ] + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + if __name__ == "__main__": parser = HfArgumentParser((Seq2SeqTrainingArguments,)) training_args = parser.parse_args_into_dataclasses()[0]