xpu: support xpu backend from stock pytorch (>=2.4) (#31238)
* xpu: support xpu backend from stock pytorch (>=2.4) Fixes: https://github.com/huggingface/transformers/issues/31237 XPU backend is available in the stock PyTorch starting from version 2.4, see [1]. This commit extends huggingface transformers to support XPU from both IPEX and stock PyTorch. IPEX is tried first. [1]: https://github.com/pytorch/pytorch/issues/114842 Requires: https://github.com/huggingface/accelerate/pull/2825 Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> * xpu: enable gpt2 and decision_transformer tests for xpu pytorch backend Note that running xpu tests requires TRANSFORMERS_TEST_DEVICE_SPEC=spec.py passed to the test runner: import torch DEVICE_NAME = 'xpu' MANUAL_SEED_FN = torch.xpu.manual_seed EMPTY_CACHE_FN = torch.xpu.empty_cache DEVICE_COUNT_FN = torch.xpu.device_count Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> --------- Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
@@ -22,7 +22,6 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
@@ -219,7 +218,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
scale_factor /= float(self.layer_idx + 1)
|
||||
|
||||
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
||||
with autocast(enabled=False):
|
||||
with torch.amp.autocast(query.device.type, enabled=False):
|
||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
||||
|
||||
@@ -25,7 +25,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@@ -249,7 +248,7 @@ class GPT2Attention(nn.Module):
|
||||
scale_factor /= float(self.layer_idx + 1)
|
||||
|
||||
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
||||
with autocast(enabled=False):
|
||||
with torch.amp.autocast(query.device.type, enabled=False):
|
||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
||||
|
||||
@@ -813,23 +813,24 @@ def require_torch_multi_npu(test_case):
|
||||
|
||||
def require_torch_xpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires XPU and IPEX.
|
||||
Decorator marking a test that requires XPU (in PyTorch).
|
||||
|
||||
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
|
||||
version.
|
||||
These tests are skipped when XPU backend is not available. XPU backend might be available either via stock
|
||||
PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version
|
||||
must match the current PyTorch version.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
|
||||
return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)
|
||||
|
||||
|
||||
def require_torch_multi_xpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-XPU setup with IPEX and at least one XPU device. These tests are
|
||||
skipped on a machine without IPEX or multiple XPUs.
|
||||
Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without
|
||||
multiple XPUs.
|
||||
|
||||
To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
|
||||
"""
|
||||
if not is_torch_xpu_available():
|
||||
return unittest.skip("test requires IPEX and at least one XPU device")(test_case)
|
||||
return unittest.skip("test requires PyTorch XPU")(test_case)
|
||||
|
||||
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ from .utils import (
|
||||
ExplicitEnum,
|
||||
cached_property,
|
||||
is_accelerate_available,
|
||||
is_ipex_available,
|
||||
is_safetensors_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
@@ -2136,6 +2137,8 @@ class TrainingArguments:
|
||||
if self.use_cpu:
|
||||
device = torch.device("cpu")
|
||||
elif is_torch_xpu_available():
|
||||
if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
|
||||
raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")
|
||||
device = torch.device("xpu:0")
|
||||
torch.xpu.set_device(device)
|
||||
elif is_torch_mlu_available():
|
||||
|
||||
@@ -747,13 +747,18 @@ def is_ipex_available():
|
||||
|
||||
@lru_cache
|
||||
def is_torch_xpu_available(check_device=False):
|
||||
"Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
|
||||
if not is_ipex_available():
|
||||
"""
|
||||
Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or
|
||||
via stock PyTorch (>=2.4) and potentially if a XPU is in the environment
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
import torch
|
||||
|
||||
if is_ipex_available():
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
|
||||
if check_device:
|
||||
try:
|
||||
# Will raise a RuntimeError if no XPU is found
|
||||
|
||||
Reference in New Issue
Block a user