tp plan should not be NONE (#38255)
* accept custom device_mesh * fix device_map * assert that num_heads % tp_size == 0 * todo. * ReplicateParallel * handle tied weights * handle dtensor in save_pretrained with safe_serialization * tp test works * doesnt work * fix shard_and_distribute_module's rank should be local_rank * tp=4 is correct * dp+tp is broken * todo allreduce with dtensors on another dim is annoying * workaround to sync dp grads when using dtensors * loading a checkpoint works * wandb and compare losses with different tp/dp * cleaning * cleaning * . * . * logs * CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention * DP=2 TP=2 now works even with tied embeddings * model.parameters() and model.module.parameters() are empty.. * reformat sanity_check_tensor_sync * set atol=1e-4 for CP to pass * try populate _parameters from named_modules * refactors TP2 DP2 works CP2 DP2 works * is_causal=True and pack sequences, no attn mask, and preshuffle dataset * fix packing * CP=4 doesn't work * fix labels and position_ids for CP * DP CP works with transformers 🥳🥳🥳 * refactor * add example cp * fixup * revert sdpa changes * example cleared * add CP, DP to the mesh init * nit * clean * use `ALL_PARALLEL_STYLES` * style * FSDP works * log on 1 rank * . * fix? * FSDP1 also has .parameters() bug * reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay * . * style and fixup * move stuff around * fix tests * style * let's make it a check * add missing licences * warning should be an info * tp plan should not be NONE * test all * god damn it * test all --------- Co-authored-by: nouamanetazi <nouamane98@gmail.com>
This commit is contained in:
@@ -62,8 +62,9 @@ from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
ALL_PARALLEL_STYLES,
|
||||
_get_parameter_tp_plan,
|
||||
initialize_tensor_parallelism,
|
||||
repack_weights,
|
||||
replace_state_dict_local_with_dtensor,
|
||||
shard_and_distribute_module,
|
||||
@@ -797,7 +798,7 @@ def _load_state_dict_into_meta_model(
|
||||
param_name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
int(os.environ["RANK"]), # the rank
|
||||
device_mesh.get_local_rank(),
|
||||
device_mesh,
|
||||
)
|
||||
else:
|
||||
@@ -1964,9 +1965,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
|
||||
for _, v in self._tp_plan.items():
|
||||
if v not in SUPPORTED_TP_STYLES:
|
||||
if v not in ALL_PARALLEL_STYLES:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
|
||||
)
|
||||
|
||||
def dequantize(self):
|
||||
@@ -3559,6 +3560,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
|
||||
|
||||
if safe_serialization:
|
||||
# TODO: fix safe_serialization for tied weights
|
||||
# Safetensors does not allow tensor aliasing.
|
||||
# We're going to remove aliases before saving
|
||||
ptrs = collections.defaultdict(list)
|
||||
@@ -4040,6 +4042,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
|
||||
tp_size (`str`, *optional*):
|
||||
A torch tensor parallel degree. If not provided would default to world size.
|
||||
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
|
||||
A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
@@ -4137,6 +4141,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
device_mesh = kwargs.pop("device_mesh", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
@@ -4172,59 +4177,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
|
||||
# `device_map` pointing to the correct device
|
||||
device_mesh = None
|
||||
if tp_plan is not None:
|
||||
if not is_torch_greater_or_equal("2.5"):
|
||||
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
try:
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
if device_type == "cuda":
|
||||
torch.distributed.init_process_group(
|
||||
"nccl", rank=rank, world_size=world_size, init_method="env://"
|
||||
)
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
elif device_type == "cpu":
|
||||
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
|
||||
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
|
||||
elif device_type == "xpu":
|
||||
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
|
||||
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
elif device_type == "hpu":
|
||||
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
|
||||
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"We tried to initialize torch.distributed for you, but it failed, make"
|
||||
"sure you init torch distributed in your script to use `tp_plan='auto'`"
|
||||
) from e
|
||||
|
||||
# Get device with index assuming equal number of devices per host
|
||||
if device_type == "xpu":
|
||||
index = torch.xpu.current_device()
|
||||
elif device_type == "hpu":
|
||||
index = torch.hpu.current_device()
|
||||
if device_mesh is None and tp_plan is not None:
|
||||
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
|
||||
else:
|
||||
index = None if device_type == "cpu" else torch.cuda.current_device()
|
||||
tp_device = torch.device(device_type, index)
|
||||
|
||||
if index is not None and index > 0:
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
# This is the easiest way to dispatch to the current process device
|
||||
device_map = tp_device
|
||||
|
||||
# Assuming sharding the model onto the world when tp_size not provided
|
||||
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
||||
# TODO: make device_mesh support multiple dimensions
|
||||
if device_mesh.ndim == 1:
|
||||
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@@ -5142,7 +5102,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
os.environ["RANK"],
|
||||
device_mesh.get_local_rank(),
|
||||
device_mesh,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user