Enable generation FSDP/utils test cases on XPU (#38009)
* Enable generation FSDP/utils test cases on XPU
  Signed-off-by: Yao Matrix <matrix.yao@intel.com>
* Fix style
  Signed-off-by: Yao Matrix <matrix.yao@intel.com>
* xx
  Signed-off-by: Yao Matrix <matrix.yao@intel.com>
* Use backend_xx APIs
  Signed-off-by: Yao Matrix <matrix.yao@intel.com>
* Fix style
  Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -15,19 +15,29 @@
|
||||
import argparse
|
||||
from typing import Any, Callable
|
||||
|
||||
from transformers import is_torch_available, is_torch_mlu_available
|
||||
from transformers import is_torch_available, is_torch_xpu_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
backend_torch_accelerator_module,
|
||||
execute_subprocess_async,
|
||||
get_torch_dist_unique_port,
|
||||
require_torch_multi_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_ccl_available, is_ipex_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
if is_torch_xpu_available():
|
||||
if is_ipex_available():
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
if is_ccl_available():
|
||||
import oneccl_bindings_for_pytorch # noqa: F401
|
||||
import torch.distributed
|
||||
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
@@ -46,10 +56,7 @@ if is_torch_available():
|
||||
"""Manage the creation and destruction of the distributed process group for the wrapped function."""
|
||||
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
if is_torch_mlu_available():
|
||||
device_count = torch.mlu.device_count()
|
||||
else:
|
||||
device_count = torch.cuda.device_count()
|
||||
device_count = backend_device_count(torch_device)
|
||||
torch.distributed.init_process_group(world_size=device_count)
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -60,10 +67,8 @@ if is_torch_available():
|
||||
|
||||
@manage_process_group
|
||||
def fsdp_generate():
|
||||
if is_torch_mlu_available():
|
||||
torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
else:
|
||||
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
torch_accelerator_module = backend_torch_accelerator_module(torch_device)
|
||||
torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
|
||||
|
||||
@@ -86,10 +91,8 @@ if is_torch_available():
|
||||
|
||||
@manage_process_group
|
||||
def fsdp2_generate():
|
||||
if is_torch_mlu_available():
|
||||
torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
else:
|
||||
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
torch_accelerator_module = backend_torch_accelerator_module(torch_device)
|
||||
torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
|
||||
|
||||
@@ -114,10 +117,7 @@ if is_torch_available():
|
||||
class TestFSDPGeneration(TestCasePlus):
|
||||
@require_torch_multi_accelerator
|
||||
def test_fsdp_generate(self):
|
||||
if is_torch_mlu_available():
|
||||
device_count = torch.mlu.device_count()
|
||||
else:
|
||||
device_count = torch.cuda.device_count()
|
||||
device_count = backend_device_count(torch_device)
|
||||
distributed_args = f"""--nproc_per_node={device_count}
|
||||
--master_port={get_torch_dist_unique_port()}
|
||||
{self.test_file_dir}/test_fsdp.py
|
||||
@@ -129,10 +129,8 @@ class TestFSDPGeneration(TestCasePlus):
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_fsdp2_generate(self):
|
||||
if is_torch_mlu_available():
|
||||
device_count = torch.mlu.device_count()
|
||||
else:
|
||||
device_count = torch.cuda.device_count()
|
||||
device_count = backend_device_count(torch_device)
|
||||
|
||||
distributed_args = f"""--nproc_per_node={device_count}
|
||||
--master_port={get_torch_dist_unique_port()}
|
||||
{self.test_file_dir}/test_fsdp.py
|
||||
|
||||
Reference in New Issue
Block a user