Device-agnostic FSDP testing (#27120)
* make FSDP test cases device-agnostic * make style
This commit is contained in:
@@ -24,18 +24,19 @@ from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
|||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
|
backend_device_count,
|
||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
get_gpu_count,
|
|
||||||
mockenv_context,
|
mockenv_context,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_fsdp,
|
require_fsdp,
|
||||||
require_torch_gpu,
|
require_torch_accelerator,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_accelerator,
|
||||||
slow,
|
slow,
|
||||||
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import FSDPOption, set_seed
|
from transformers.trainer_utils import FSDPOption, set_seed
|
||||||
from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
|
from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -46,7 +47,7 @@ else:
|
|||||||
# default torch.distributed port
|
# default torch.distributed port
|
||||||
DEFAULT_MASTER_PORT = "10999"
|
DEFAULT_MASTER_PORT = "10999"
|
||||||
dtypes = ["fp16"]
|
dtypes = ["fp16"]
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_available_on_device(torch_device):
|
||||||
dtypes += ["bf16"]
|
dtypes += ["bf16"]
|
||||||
sharding_strategies = ["full_shard", "shard_grad_op"]
|
sharding_strategies = ["full_shard", "shard_grad_op"]
|
||||||
state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
|
state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
|
||||||
@@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
|
|||||||
# - it won't be able to handle that
|
# - it won't be able to handle that
|
||||||
# 2. for now testing with just 2 gpus max (since some quality tests may give different
|
# 2. for now testing with just 2 gpus max (since some quality tests may give different
|
||||||
# results with more gpus because we use very little data)
|
# results with more gpus because we use very little data)
|
||||||
num_gpus = min(2, get_gpu_count()) if distributed else 1
|
num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
|
||||||
master_port = get_master_port(real_launcher=True)
|
master_port = get_master_port(real_launcher=True)
|
||||||
if use_accelerate:
|
if use_accelerate:
|
||||||
return f"""accelerate launch
|
return f"""accelerate launch
|
||||||
@@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):
|
|||||||
|
|
||||||
|
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
@require_fsdp_version
|
@require_fsdp_version
|
||||||
class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -170,7 +171,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
|
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
|
||||||
|
|
||||||
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
|
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
@slow
|
@slow
|
||||||
def test_basic_run(self, sharding_strategy, dtype):
|
def test_basic_run(self, sharding_strategy, dtype):
|
||||||
launcher = get_launcher(distributed=True, use_accelerate=False)
|
launcher = get_launcher(distributed=True, use_accelerate=False)
|
||||||
@@ -182,7 +183,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
execute_subprocess_async(cmd, env=self.get_env())
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
|
|
||||||
@parameterized.expand(dtypes)
|
@parameterized.expand(dtypes)
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
@slow
|
@slow
|
||||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
|
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
|
||||||
def test_basic_run_with_cpu_offload(self, dtype):
|
def test_basic_run_with_cpu_offload(self, dtype):
|
||||||
@@ -195,7 +196,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
execute_subprocess_async(cmd, env=self.get_env())
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
|
|
||||||
@parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
|
@parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
@slow
|
@slow
|
||||||
def test_training_and_can_resume_normally(self, state_dict_type):
|
def test_training_and_can_resume_normally(self, state_dict_type):
|
||||||
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
|
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user