enable tp on CPU (#36299)

* enable tp on CPU

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* get rank from cpu

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable TP tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* em print

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix model id

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix conflict

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix index and add doc

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
jiqing-feng
2025-03-31 16:55:47 +08:00
committed by GitHub
parent 4705b04c74
commit 286393fbb1
3 changed files with 56 additions and 113 deletions

View File

@@ -774,7 +774,8 @@ def _load_state_dict_into_meta_model(
"""
tensor_device = "cpu"
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
if device_map[""] not in ("cpu", torch.device("cpu")):
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
if device_map is not None:
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
@@ -4110,24 +4111,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
if device_type == "cuda":
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "cpu":
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
tp_device = torch.device(device_type, torch.cuda.current_device())
if tp_device.index > 0:
# Get device with index assuming equal number of devices per host
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)
if index is not None and index > 0:
import sys
sys.stdout = open(os.devnull, "w")