exclude fsdp from delay_optimizer_creation (#34140)
* exclude fsdp from delay_optimizer_creation * add test case for trainer: FSDP mode and fp8 as mixed precision * rearrange imports * ruff formatted * adapt _init_fsdp to fp8 * use _init_fsdp only when resume_from_checkpoint * In case of FDP, self.layer will be CheckpointWrapper which has no len() method * delete _init_fsdp * solve conflict * fix conflict * make fixup
This commit is contained in:
committed by
GitHub
parent
92bcdff2ef
commit
8b3b9b48fc
@@ -20,6 +20,8 @@ from transformers.testing_utils import (
|
||||
execute_subprocess_async,
|
||||
get_torch_dist_unique_port,
|
||||
require_accelerate,
|
||||
require_fp8,
|
||||
require_fsdp,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
|
||||
@@ -64,6 +66,7 @@ if is_torch_available():
|
||||
class TestFSDPTrainer(TestCasePlus):
|
||||
@require_accelerate
|
||||
@require_torch_multi_gpu
|
||||
@require_fsdp
|
||||
def test_trainer(self):
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
cmd = [
|
||||
@@ -86,6 +89,35 @@ class TestFSDPTrainer(TestCasePlus):
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
|
||||
class TestFSDPTrainerFP8(TestCasePlus):
|
||||
@require_accelerate
|
||||
@require_torch_multi_gpu
|
||||
@require_fsdp
|
||||
@require_fp8
|
||||
def test_trainer(self):
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--use_fsdp",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"--num_processes",
|
||||
f"{torch.cuda.device_count()}",
|
||||
"--mixed_precision",
|
||||
"fp8",
|
||||
"--fsdp_transformer_layer_cls_to_wrap",
|
||||
"GPT2Block",
|
||||
f"{self.test_file_dir}/test_trainer_fsdp.py",
|
||||
"--output_dir",
|
||||
f"{output_dir}",
|
||||
"--report_to",
|
||||
"none",
|
||||
]
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((Seq2SeqTrainingArguments,))
|
||||
training_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
Reference in New Issue
Block a user