[MLU] Fix FA2 check error, remove deepspeed-mlu deps. (#36159)

* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker

* Cambricon support SDPA and flash_attn

* MLU devices : Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu

* Fix mlu FA2 check. Remove deepspeed-mlu check. add mlu tests support.

* fix testing errors.

* Merge branch 'hf/main' into main

* fix get_device_count error.

* fix mlu testing utils.

* fix code quality and style.

* switch to @require_torch_multi_accelerator
This commit is contained in:
huismiling
2025-03-31 17:02:49 +08:00
committed by GitHub
parent ad63d20dff
commit d0b65bb479
4 changed files with 63 additions and 18 deletions

View File

@@ -15,12 +15,12 @@
import argparse
from typing import Any, Callable
from transformers import is_torch_available
from transformers import is_torch_available, is_torch_mlu_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
require_torch_multi_accelerator,
)
@@ -46,7 +46,11 @@ if is_torch_available():
"""Manage the creation and destruction of the distributed process group for the wrapped function."""
def wrapped(*args: Any, **kwargs: Any) -> Any:
torch.distributed.init_process_group(world_size=torch.cuda.device_count())
if is_torch_mlu_available():
device_count = torch.mlu.device_count()
else:
device_count = torch.cuda.device_count()
torch.distributed.init_process_group(world_size=device_count)
try:
return func(*args, **kwargs)
finally:
@@ -56,7 +60,10 @@ if is_torch_available():
@manage_process_group
def fsdp_generate():
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
if is_torch_mlu_available():
torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
else:
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
@@ -79,11 +86,14 @@ if is_torch_available():
@manage_process_group
def fsdp2_generate():
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
if is_torch_mlu_available():
torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
else:
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
for submodule in model.modules():
if isinstance(submodule, GPT2Block):
fully_shard(submodule, mesh=mesh)
@@ -102,9 +112,13 @@ if is_torch_available():
class TestFSDPGeneration(TestCasePlus):
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_fsdp_generate(self):
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
if is_torch_mlu_available():
device_count = torch.mlu.device_count()
else:
device_count = torch.cuda.device_count()
distributed_args = f"""--nproc_per_node={device_count}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_fsdp.py
""".split()
@@ -113,9 +127,13 @@ class TestFSDPGeneration(TestCasePlus):
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_fsdp2_generate(self):
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
if is_torch_mlu_available():
device_count = torch.mlu.device_count()
else:
device_count = torch.cuda.device_count()
distributed_args = f"""--nproc_per_node={device_count}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_fsdp.py
""".split()