enable tp on CPU (#36299)
* enable tp on CPU Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * get rank from cpu Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable TP tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix comment Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * em print Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix model id Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix conflict Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix index and add doc Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -774,7 +774,8 @@ def _load_state_dict_into_meta_model(
|
||||
"""
|
||||
tensor_device = "cpu"
|
||||
if device_map is not None and device_map.get("", None) is not None:
|
||||
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
||||
if device_map[""] not in ("cpu", torch.device("cpu")):
|
||||
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
||||
if device_map is not None:
|
||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||
|
||||
@@ -4110,24 +4111,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if tp_plan is not None:
|
||||
if not is_torch_greater_or_equal("2.5"):
|
||||
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
try:
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
torch.distributed.init_process_group(
|
||||
"nccl", rank=rank, world_size=world_size, init_method="env://"
|
||||
)
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
if device_type == "cuda":
|
||||
torch.distributed.init_process_group(
|
||||
"nccl", rank=rank, world_size=world_size, init_method="env://"
|
||||
)
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
elif device_type == "cpu":
|
||||
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
|
||||
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
|
||||
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"We tried to initialize torch.distributed for you, but it failed, make"
|
||||
"sure you init torch distributed in your script to use `tp_plan='auto'`"
|
||||
) from e
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
tp_device = torch.device(device_type, torch.cuda.current_device())
|
||||
if tp_device.index > 0:
|
||||
# Get device with index assuming equal number of devices per host
|
||||
index = None if device_type == "cpu" else torch.cuda.current_device()
|
||||
tp_device = torch.device(device_type, index)
|
||||
|
||||
if index is not None and index > 0:
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
Reference in New Issue
Block a user