Fixing the docs corresponding to the breaking change in torch 2.6. (#36420)
This commit is contained in:
@@ -29,6 +29,7 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|||||||
# Initialize distributed
|
# Initialize distributed
|
||||||
rank = int(os.environ["RANK"])
|
rank = int(os.environ["RANK"])
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
torch.cuda.set_device(device)
|
||||||
torch.distributed.init_process_group("nccl", device_id=device)
|
torch.distributed.init_process_group("nccl", device_id=device)
|
||||||
|
|
||||||
# Retrieve tensor parallel model
|
# Retrieve tensor parallel model
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|||||||
# 初始化分布式环境
|
# 初始化分布式环境
|
||||||
rank = int(os.environ["RANK"])
|
rank = int(os.environ["RANK"])
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
torch.cuda.set_device(device)
|
||||||
torch.distributed.init_process_group("nccl", device_id=device)
|
torch.distributed.init_process_group("nccl", device_id=device)
|
||||||
|
|
||||||
# 获取支持张量并行的模型
|
# 获取支持张量并行的模型
|
||||||
|
|||||||
Reference in New Issue
Block a user