Update-tp test (#35844)

* update test for now

* up

* cleanup

* update todo
This commit is contained in:
Arthur
2025-02-03 09:37:02 +01:00
committed by GitHub
parent 62db3e6ed6
commit 7eecdf2a86
2 changed files with 36 additions and 12 deletions

View File

@@ -343,6 +343,8 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
return torch.isin(elements, test_elements) return torch.isin(elements, test_elements)
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
def translate_to_torch_parallel_style(style: str): def translate_to_torch_parallel_style(style: str):
""" """
In model configurations, we use a neutral type (string) to specify parallel In model configurations, we use a neutral type (string) to specify parallel

View File

@@ -17,6 +17,7 @@ import subprocess
import tempfile import tempfile
import textwrap import textwrap
# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel from transformers.models.llama.modeling_llama import LlamaModel
@@ -110,9 +111,8 @@ if __name__ == "__main__":
# Test settings # Test settings
model_id = "meta-llama/Meta-Llama-3-8B-Instruct" model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
bs = 4 bs = 1
seqlen = 64 seqlen = 4096
# Get distributed settings # Get distributed settings
rank = int(os.environ["RANK"]) rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"]) world_size = int(os.environ["WORLD_SIZE"])
@@ -124,23 +124,45 @@ if __name__ == "__main__":
# Get model config # Get model config
config = LlamaConfig.from_pretrained(model_id) config = LlamaConfig.from_pretrained(model_id)
# Shrink model size config.hidden_size = 2048
config.num_hidden_layers //= 8 config.attention_bias = False
config.vocab_size //= 8
# Instantiate model # Instantiate model
with device: with device:
model = LlamaModel(config) model = LlamaModel(config).to(dtype=torch.float16)
model.eval() model.eval()
# Tensor Parallel # Tensor Parallel
if world_size > 1: if world_size > 1:
model.tensor_parallel(device_mesh) model.tensor_parallel(device_mesh)
# Run model # Run model
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device) inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
with torch.no_grad():
out = model(inputs) # Test cuda graphing explicitly
with torch.cuda.device(device):
print("Cuda graphing")
with torch.no_grad():
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
# CUDA Graph setup
s = torch.cuda.Stream(device=device)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for i in range(3):
out = model(inputs)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
out = model(inputs)
for _ in range(2):
g.replay()
s.synchronize()
assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size]) assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
# Test compile
with torch.no_grad():
out = model(inputs)
model.forward = torch.compile(model.forward, mode="reduce-overhead")
out = model(inputs)
out = model(inputs)