Fix TP initialization (#35860)
* fix tp * Update modeling_utils.py * style * style * Update test_tp.py * Update test_tp.py * style * Update test_tp.py * Update test_tp.py * Update test_tp.py * Update test_tp.py
This commit is contained in:
@@ -3443,6 +3443,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
|
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
|
||||||
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
|
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
|
||||||
|
|
||||||
|
if tp_plan is not None and device_map is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
|
||||||
|
)
|
||||||
|
|
||||||
|
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
|
||||||
|
# `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
|
||||||
|
# childs processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
|
||||||
|
# And temporarily setting the default device to current process rank result in the following error
|
||||||
|
# `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
|
||||||
|
tp_device = None
|
||||||
|
if tp_plan is not None:
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
|
||||||
|
|
||||||
|
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||||
|
device_type = torch._C._get_accelerator().type
|
||||||
|
device_module = torch.get_device_module(device_type)
|
||||||
|
# Get device with index assuming equal number of devices per host
|
||||||
|
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
|
||||||
|
# This is the easiest way to dispatch to the current process device
|
||||||
|
device_map = tp_device
|
||||||
|
|
||||||
if is_fsdp_enabled():
|
if is_fsdp_enabled():
|
||||||
low_cpu_mem_usage = True
|
low_cpu_mem_usage = True
|
||||||
|
|
||||||
@@ -4090,7 +4113,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
init_contexts = [no_init_weights(_enable=_fast_init)]
|
init_contexts = [no_init_weights(_enable=_fast_init)]
|
||||||
tp_device = None
|
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
@@ -4106,16 +4128,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
|
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
|
||||||
)
|
)
|
||||||
init_contexts.append(init_empty_weights())
|
init_contexts.append(init_empty_weights())
|
||||||
elif tp_plan is not None:
|
|
||||||
if not torch.distributed.is_initialized():
|
|
||||||
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
|
|
||||||
|
|
||||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
|
||||||
device_type = torch._C._get_accelerator().type
|
|
||||||
device_module = torch.get_device_module(device_type)
|
|
||||||
# Get device with index assuming equal number of devices per host
|
|
||||||
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
|
|
||||||
init_contexts.append(tp_device)
|
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled() and is_quantized:
|
if is_deepspeed_zero3_enabled() and is_quantized:
|
||||||
init_contexts.append(set_quantized_state())
|
init_contexts.append(set_quantized_state())
|
||||||
@@ -4249,12 +4261,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if dtype_orig is not None:
|
if dtype_orig is not None:
|
||||||
torch.set_default_dtype(dtype_orig)
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
|
||||||
load_contexts = []
|
|
||||||
# Make sure we load onto targeted device
|
|
||||||
if tp_device is not None:
|
|
||||||
load_contexts.append(tp_device)
|
|
||||||
|
|
||||||
with ContextManagers(load_contexts):
|
|
||||||
(
|
(
|
||||||
model,
|
model,
|
||||||
missing_keys,
|
missing_keys,
|
||||||
|
|||||||
@@ -13,6 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import textwrap
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
@@ -30,6 +33,22 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
class TestTensorParallel(TestCasePlus):
|
class TestTensorParallel(TestCasePlus):
|
||||||
|
def torchrun(self, script: str):
|
||||||
|
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necesary."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
|
||||||
|
tmp.write(script)
|
||||||
|
tmp.flush()
|
||||||
|
tmp.seek(0)
|
||||||
|
cmd = (
|
||||||
|
f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
|
||||||
|
).split()
|
||||||
|
|
||||||
|
# Note that the subprocess will be waited for here, and raise an error if not successful
|
||||||
|
try:
|
||||||
|
_ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
raise Exception(f"The following error was captured: {e.stderr}")
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
def test_tp(self):
|
def test_tp(self):
|
||||||
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
|
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
|
||||||
@@ -43,6 +62,42 @@ class TestTensorParallel(TestCasePlus):
|
|||||||
execute_subprocess_async(cmd, env=self.get_env())
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
# successful return here == success - any errors would have caused an error in the sub-call
|
# successful return here == success - any errors would have caused an error in the sub-call
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_loading_memory_consumption(self):
|
||||||
|
script_to_run = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
torch.distributed.init_process_group("nccl", device_id=device)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
# The expected full model memory footprint
|
||||||
|
expected_model_memory = 16
|
||||||
|
overhead_factor = 1.2
|
||||||
|
|
||||||
|
# Assert we did not use more than the full model expected memory (with some overhead)
|
||||||
|
if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
|
||||||
|
raise ValueError("Loading the model used more than the full model size")
|
||||||
|
|
||||||
|
# Assert we correctly handled the sharding between devices
|
||||||
|
if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
|
||||||
|
raise ValueError("Each model shard is larger than what is expected.")
|
||||||
|
|
||||||
|
torch.distributed.barrier()
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.torchrun(script_to_run)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
|
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
|
||||||
|
|||||||
Reference in New Issue
Block a user