enable static cache on TP model (#39164)
* enable static cache on TP model Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * check tp size before init kv cache Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix docstring Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add tp tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix comment Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix other cache head size Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -62,7 +64,6 @@ class TestTensorParallelUtils(TestCasePlus):
|
||||
assert torch.allclose(unpacked_weights, original_packed_weights)
|
||||
|
||||
|
||||
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
|
||||
class TestTensorParallel(TestCasePlus):
|
||||
nproc_per_node = 2
|
||||
|
||||
@@ -125,6 +126,46 @@ class TestTensorParallel(TestCasePlus):
|
||||
)
|
||||
self.torchrun(script_to_run)
|
||||
|
||||
def test_model_generate(self):
|
||||
script_to_run = textwrap.dedent(
|
||||
"""
|
||||
import torch
|
||||
import os
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_id = "JackFram/llama-68m"
|
||||
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
|
||||
torch.distributed.barrier()
|
||||
|
||||
model.forward = torch.compile(model.forward)
|
||||
|
||||
has_dtensor = 0
|
||||
for name, parameter in model.named_parameters():
|
||||
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
|
||||
has_dtensor = 1
|
||||
break
|
||||
|
||||
assert has_dtensor == 1, "TP model must has DTensor"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
prompt = "Can I help"
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
|
||||
outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
|
||||
|
||||
output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
|
||||
|
||||
torch.distributed.barrier()
|
||||
torch.distributed.destroy_process_group()
|
||||
"""
|
||||
)
|
||||
self.torchrun(script_to_run)
|
||||
|
||||
@require_huggingface_hub_greater_or_equal("0.31.4")
|
||||
def test_model_save(self):
|
||||
from safetensors import safe_open
|
||||
|
||||
Reference in New Issue
Block a user