Revert "remove dtensors, not explicit (#39840)" (#39912)

* Revert "remove dtensors, not explicit (#39840)"
This did not work with generation (lm_head needs extra care!)
This reverts commit 6dfd561d9c.

* update

* style?
This commit is contained in:
Arthur
2025-08-05 15:12:14 +02:00
committed by GitHub
parent 2589a52c5c
commit 20ce210ab7
4 changed files with 77 additions and 73 deletions

View File

@@ -101,6 +101,14 @@ class TestTensorParallel(TestCasePlus):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
torch.distributed.barrier()
has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break
assert has_dtensor == 1, "TP model must has DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False)
prompt = "Can I help"
@@ -110,8 +118,7 @@ class TestTensorParallel(TestCasePlus):
next_token_logits = outputs[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
response = tokenizer.decode(next_token)
print(response)
# assert response == "with"
assert response == "with"
torch.distributed.barrier()
torch.distributed.destroy_process_group()
@@ -136,6 +143,14 @@ class TestTensorParallel(TestCasePlus):
model.forward = torch.compile(model.forward)
has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break
assert has_dtensor == 1, "TP model must has DTensor"
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"