[MLU] Fix FA2 check error, remove deepspeed-mlu deps. (#36159)
* add Cambricon MLUs support * fix mlu device rng state * up for quality check * up mlu to support fp16 * fix mlu device dependency error * fix mlu device dependency error * enable mlu device for bf16 * fix mlu device memory tracker * Cambricon support SDPA and flash_attn * MLU devices: check whether `mlu` is available via a `cndev`-based check which won't trigger the drivers and leave mlu uninitialized * Fix mlu FA2 check. Remove deepspeed-mlu check. Add mlu tests support. * fix testing errors. * Merge branch 'hf/main' into main * fix get_device_count error. * fix mlu testing utils. * fix code quality and style. * switch to @require_torch_multi_accelerator
This commit is contained in:
@@ -103,6 +103,7 @@ from .utils import (
|
||||
is_safetensors_available,
|
||||
is_torch_flex_attn_available,
|
||||
is_torch_greater_or_equal,
|
||||
is_torch_mlu_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_sdpa_available,
|
||||
is_torch_xla_available,
|
||||
@@ -2323,12 +2324,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
|
||||
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
|
||||
if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
|
||||
if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
|
||||
if torch.cuda.is_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
|
||||
" after initializing it on CPU with `model.to('cuda')`."
|
||||
)
|
||||
elif is_torch_mlu_available():
|
||||
logger.warning_once(
|
||||
"You are attempting to use Flash Attention 2.0 with a model not initialized on MLU. Make sure to move the model to MLU"
|
||||
" after initializing it on CPU with `model.to('mlu')`."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
|
||||
|
||||
Reference in New Issue
Block a user