[MLU] Fix FA2 check error, remove deepspeed-mlu deps. (#36159)

* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker

* Cambricon support SDPA and flash_attn

* MLU devices: Checks if `mlu` is available via a `cndev`-based check which won't trigger the drivers and leave mlu

* Fix mlu FA2 check. Remove deepspeed-mlu check. add mlu tests support.

* fix testing errors.

* Merge branch 'hf/main' into main

* fix get_device_count error.

* fix mlu testing utils.

* fix code quality and style.

* switch to @require_torch_multi_accelerator
huismiling
2025-03-31 17:02:49 +08:00
committed by GitHub
parent ad63d20dff
commit d0b65bb479
4 changed files with 63 additions and 18 deletions
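
The device check this commit changes in the Flash Attention 2 path can be sketched in isolation as follows. This is a minimal illustration, not the library's code: the function name `fa2_device_warning` and its parameters are hypothetical, the availability flags are passed in as plain booleans instead of calling `torch.cuda.is_available()` or `is_torch_mlu_available()`, and the messages are shortened.

```python
def fa2_device_warning(default_device_type, cuda_available, mlu_available,
                       check_device_map=True, device_map=None):
    """Hypothetical stand-in for the FA2 device check touched by this commit.

    Returns None when no warning is needed, a warning string when the model
    should be moved to an accelerator, and raises ValueError when no
    supported accelerator is available.
    """
    # The check only applies when no explicit device_map was given.
    if not check_device_map or device_map is not None:
        return None
    # After this commit, "mlu" is accepted alongside "cuda" as a default device.
    if default_device_type in ("cuda", "mlu"):
        return None
    if cuda_available:
        return "Move the model to GPU with `model.to('cuda')`."
    if mlu_available:
        return "Move the model to MLU with `model.to('mlu')`."
    raise ValueError("Flash Attention 2.0 requested, but no GPU or MLU is available.")


# A model created on CPU while an MLU is present gets the MLU hint.
print(fa2_device_warning("cpu", cuda_available=False, mlu_available=True))
```

Passing the flags as arguments keeps the sketch runnable without PyTorch or Cambricon drivers installed; in the real code path the default device type comes from `torch.empty(0).device.type`.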

@@ -103,6 +103,7 @@ from .utils import (
     is_safetensors_available,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
+    is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_sdpa_available,
     is_torch_xla_available,
@@ -2323,12 +2324,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
         # or the model may be initialized under the context manager `with torch.device("cuda"):`.
-        if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
+        if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
             if torch.cuda.is_available():
                 logger.warning_once(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                     " after initializing it on CPU with `model.to('cuda')`."
                 )
+            elif is_torch_mlu_available():
+                logger.warning_once(
+                    "You are attempting to use Flash Attention 2.0 with a model not initialized on MLU. Make sure to move the model to MLU"
+                    " after initializing it on CPU with `model.to('mlu')`."
+                )
             else:
                 raise ValueError(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "