Revert parallelism temporarily (#38240)
* Revert "Protect ParallelInterface" This reverts commitcb513e35f9. * Revert "parallelism goes brrr (#37877)" This reverts commit1c2f36b480. * Empty commit
This commit is contained in:
@@ -62,9 +62,8 @@ from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
ALL_PARALLEL_STYLES,
|
||||
SUPPORTED_TP_STYLES,
|
||||
_get_parameter_tp_plan,
|
||||
initialize_tensor_parallelism,
|
||||
repack_weights,
|
||||
replace_state_dict_local_with_dtensor,
|
||||
shard_and_distribute_module,
|
||||
@@ -798,7 +797,7 @@ def _load_state_dict_into_meta_model(
|
||||
param_name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
device_mesh.get_local_rank(),
|
||||
int(os.environ["RANK"]), # the rank
|
||||
device_mesh,
|
||||
)
|
||||
else:
|
||||
@@ -1965,9 +1964,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
|
||||
for _, v in self._tp_plan.items():
|
||||
if v not in ALL_PARALLEL_STYLES:
|
||||
if v not in SUPPORTED_TP_STYLES:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
|
||||
)
|
||||
|
||||
def dequantize(self):
|
||||
@@ -3560,7 +3559,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
|
||||
|
||||
if safe_serialization:
|
||||
# TODO: fix safe_serialization for tied weights
|
||||
# Safetensors does not allow tensor aliasing.
|
||||
# We're going to remove aliases before saving
|
||||
ptrs = collections.defaultdict(list)
|
||||
@@ -4042,8 +4040,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
|
||||
tp_size (`str`, *optional*):
|
||||
A torch tensor parallel degree. If not provided would default to world size.
|
||||
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
|
||||
A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
@@ -4141,7 +4137,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
device_mesh = kwargs.pop("device_mesh", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
@@ -4177,13 +4172,59 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
|
||||
# `device_map` pointing to the correct device
|
||||
if device_mesh is None:
|
||||
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
|
||||
else:
|
||||
# TODO: make device_mesh support multiple dimensions
|
||||
if device_mesh.ndim == 1:
|
||||
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||
device_mesh = None
|
||||
if tp_plan is not None:
|
||||
if not is_torch_greater_or_equal("2.5"):
|
||||
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
try:
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
if device_type == "cuda":
|
||||
torch.distributed.init_process_group(
|
||||
"nccl", rank=rank, world_size=world_size, init_method="env://"
|
||||
)
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
elif device_type == "cpu":
|
||||
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
|
||||
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
|
||||
elif device_type == "xpu":
|
||||
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
|
||||
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
elif device_type == "hpu":
|
||||
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
|
||||
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"We tried to initialize torch.distributed for you, but it failed, make"
|
||||
"sure you init torch distributed in your script to use `tp_plan='auto'`"
|
||||
) from e
|
||||
|
||||
# Get device with index assuming equal number of devices per host
|
||||
if device_type == "xpu":
|
||||
index = torch.xpu.current_device()
|
||||
elif device_type == "hpu":
|
||||
index = torch.hpu.current_device()
|
||||
else:
|
||||
index = None if device_type == "cpu" else torch.cuda.current_device()
|
||||
tp_device = torch.device(device_type, index)
|
||||
|
||||
if index is not None and index > 0:
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
# This is the easiest way to dispatch to the current process device
|
||||
device_map = tp_device
|
||||
|
||||
# Assuming sharding the model onto the world when tp_size not provided
|
||||
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@@ -5101,7 +5142,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
device_mesh.get_local_rank(),
|
||||
os.environ["RANK"],
|
||||
device_mesh,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user