Update-tp test (#35844)
* update test for now
* up
* cleanup
* update todo
This commit is contained in:
@@ -343,6 +343,8 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
|
|||||||
return torch.isin(elements, test_elements)
|
return torch.isin(elements, test_elements)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add a __repr__ that shows that the layer is colwise-parallel
|
||||||
|
# See https://github.com/pytorch/pytorch/issues/145726
|
||||||
def translate_to_torch_parallel_style(style: str):
|
def translate_to_torch_parallel_style(style: str):
|
||||||
"""
|
"""
|
||||||
In model configurations, we use a neutral type (string) to specify parallel
|
In model configurations, we use a neutral type (string) to specify parallel
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import subprocess
|
|||||||
import tempfile
|
import tempfile
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from transformers.models.llama.modeling_llama import LlamaModel
|
from transformers.models.llama.modeling_llama import LlamaModel
|
||||||
@@ -110,9 +111,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Test settings
|
# Test settings
|
||||||
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
bs = 4
|
bs = 1
|
||||||
seqlen = 64
|
seqlen = 4096
|
||||||
|
|
||||||
# Get distributed settings
|
# Get distributed settings
|
||||||
rank = int(os.environ["RANK"])
|
rank = int(os.environ["RANK"])
|
||||||
world_size = int(os.environ["WORLD_SIZE"])
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
@@ -124,23 +124,45 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Get model config
|
# Get model config
|
||||||
config = LlamaConfig.from_pretrained(model_id)
|
config = LlamaConfig.from_pretrained(model_id)
|
||||||
# Shrink model size
|
config.hidden_size = 2048
|
||||||
config.num_hidden_layers //= 8
|
config.attention_bias = False
|
||||||
config.vocab_size //= 8
|
|
||||||
|
|
||||||
# Instantiate model
|
# Instantiate model
|
||||||
with device:
|
with device:
|
||||||
model = LlamaModel(config)
|
model = LlamaModel(config).to(dtype=torch.float16)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# Tensor Parallel
|
# Tensor Parallel
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
model.tensor_parallel(device_mesh)
|
model.tensor_parallel(device_mesh)
|
||||||
|
|
||||||
# Run model
|
# Run model
|
||||||
|
|
||||||
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
|
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
|
||||||
with torch.no_grad():
|
|
||||||
out = model(inputs)
|
# Test cuda graphing explicitly
|
||||||
|
with torch.cuda.device(device):
|
||||||
|
print("Cuda graphing")
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
|
||||||
|
# CUDA Graph setup
|
||||||
|
s = torch.cuda.Stream(device=device)
|
||||||
|
s.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s):
|
||||||
|
for i in range(3):
|
||||||
|
out = model(inputs)
|
||||||
|
torch.cuda.current_stream().wait_stream(s)
|
||||||
|
g = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(g):
|
||||||
|
out = model(inputs)
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
g.replay()
|
||||||
|
s.synchronize()
|
||||||
|
|
||||||
assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
|
assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
|
||||||
|
|
||||||
|
# Test compile
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(inputs)
|
||||||
|
model.forward = torch.compile(model.forward, mode="reduce-overhead")
|
||||||
|
out = model(inputs)
|
||||||
|
out = model(inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user