* Revert "remove dtensors, not explicit (#39840)"
This did not work with generation (lm_head needs extra care!)
This reverts commit 6dfd561d9c.
* update
* style?
This commit is contained in:
@@ -101,6 +101,14 @@ class TestTensorParallel(TestCasePlus):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
|
||||
torch.distributed.barrier()
|
||||
|
||||
+ has_dtensor = 0
|
||||
+ for name, parameter in model.named_parameters():
|
||||
+     if isinstance(parameter.data, torch.distributed.tensor.DTensor):
|
||||
+         has_dtensor = 1
|
||||
+         break
|
||||
|
||||
+ assert has_dtensor == 1, "TP model must has DTensor"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False)
|
||||
prompt = "Can I help"
|
||||
|
||||
@@ -110,8 +118,7 @@ class TestTensorParallel(TestCasePlus):
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
next_token = torch.argmax(next_token_logits, dim=-1)
|
||||
response = tokenizer.decode(next_token)
|
||||
- print(response)
|
||||
- # assert response == "with"
|
||||
+ assert response == "with"
|
||||
|
||||
torch.distributed.barrier()
|
||||
torch.distributed.destroy_process_group()
|
||||
@@ -136,6 +143,14 @@ class TestTensorParallel(TestCasePlus):
|
||||
|
||||
model.forward = torch.compile(model.forward)
|
||||
|
||||
+ has_dtensor = 0
|
||||
+ for name, parameter in model.named_parameters():
|
||||
+     if isinstance(parameter.data, torch.distributed.tensor.DTensor):
|
||||
+         has_dtensor = 1
|
||||
+         break
|
||||
|
||||
+ assert has_dtensor == 1, "TP model must has DTensor"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
prompt = "Can I help"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user