enable static cache on TP model (#39164)
* enable static cache on TP model
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* check tp size before init kv cache
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix docstring
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* add tp tests
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix comment
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix other cache head size
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -1098,6 +1098,10 @@ class StaticCache(Cache):
|
||||
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
||||
and the model is split between different gpus. You can know which layers mapped to which device by
|
||||
checking the associated device_map: `model.hf_device_map`.
|
||||
tp_size (`Optional[int]`, *optional*):
|
||||
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
|
||||
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
|
||||
number of key/value heads will not be adjusted.
|
||||
|
||||
|
||||
Example:
|
||||
@@ -1130,6 +1134,7 @@ class StaticCache(Cache):
|
||||
device: Union[torch.device, str, None] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_batch_size = max_batch_size
|
||||
@@ -1144,6 +1149,13 @@ class StaticCache(Cache):
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
if tp_size is not None and tp_size > 1:
|
||||
if self.num_key_value_heads % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
|
||||
)
|
||||
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
|
||||
self.num_key_value_heads //= tp_size
|
||||
|
||||
self.key_cache: list[torch.Tensor] = []
|
||||
self.value_cache: list[torch.Tensor] = []
|
||||
@@ -1573,6 +1585,10 @@ class HybridCache(Cache):
|
||||
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
||||
and the model is split between different gpus. You can know which layers mapped to which device by
|
||||
checking the associated device_map: `model.hf_device_map`.
|
||||
tp_size (`Optional[int]`, *optional*):
|
||||
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
|
||||
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
|
||||
number of key/value heads will not be adjusted.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -1604,6 +1620,7 @@ class HybridCache(Cache):
|
||||
device: Union[torch.device, str, None] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||
@@ -1627,6 +1644,13 @@ class HybridCache(Cache):
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
if tp_size is not None and tp_size > 1:
|
||||
if self.num_key_value_heads % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
|
||||
)
|
||||
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
|
||||
self.num_key_value_heads //= tp_size
|
||||
|
||||
# If the attribute does not exist in the config, fallback to a simple StaticCache
|
||||
if hasattr(config, "layer_types"):
|
||||
@@ -2197,6 +2221,10 @@ class OffloadedStaticCache(StaticCache):
|
||||
Mapping between the layers and its device. This is required when you are manually initializing the cache
|
||||
and the model is split between different gpus. You can know which layers mapped to which device by
|
||||
checking the associated device_map: `model.hf_device_map`.
|
||||
tp_size (`Optional[int]`, *optional*):
|
||||
The tensor parallel size of the model. This is used to adjust the number of key/value heads in the cache
|
||||
if the model is using tensor parallelism. If not provided, it defaults to `None`, which means that the
|
||||
number of key/value heads will not be adjusted.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -2228,6 +2256,7 @@ class OffloadedStaticCache(StaticCache):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
||||
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
) -> None:
|
||||
super(Cache, self).__init__()
|
||||
|
||||
@@ -2251,6 +2280,13 @@ class OffloadedStaticCache(StaticCache):
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
if tp_size is not None and tp_size > 1:
|
||||
if num_key_value_heads % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
|
||||
)
|
||||
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
|
||||
num_key_value_heads //= tp_size
|
||||
|
||||
cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
|
||||
|
||||
|
||||
@@ -1963,6 +1963,9 @@ class GenerationMixin(ContinuousMixin):
|
||||
"device": device,
|
||||
"layer_device_map": layer_device_map,
|
||||
}
|
||||
if cache_implementation in ["static", "hybrid", "offloaded_static"]:
|
||||
cache_kwargs.update({"tp_size": self.tp_size})
|
||||
|
||||
self._cache = cache_cls(**cache_kwargs)
|
||||
if requires_cross_attention_cache:
|
||||
encoder_kwargs = cache_kwargs.copy()
|
||||
|
||||
@@ -4494,6 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
if tp_size is None:
|
||||
tp_size = torch.distributed.get_world_size()
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Run the test: CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -62,7 +64,6 @@ class TestTensorParallelUtils(TestCasePlus):
|
||||
assert torch.allclose(unpacked_weights, original_packed_weights)
|
||||
|
||||
|
||||
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
|
||||
class TestTensorParallel(TestCasePlus):
|
||||
nproc_per_node = 2
|
||||
|
||||
@@ -125,6 +126,46 @@ class TestTensorParallel(TestCasePlus):
|
||||
)
|
||||
self.torchrun(script_to_run)
|
||||
|
||||
def test_model_generate(self):
    """Check that a tensor-parallel model can generate with a static KV cache.

    The script below is passed to ``self.torchrun`` (helper defined outside
    this method; presumably it launches the script under ``torchrun`` with
    ``nproc_per_node`` processes -- confirm against the helper). Inside the
    script the model is loaded with ``tp_plan="auto"``, its forward is
    compiled, and ``generate`` is called with ``cache_implementation="static"``.

    The script asserts that:
      * at least one model parameter is a ``DTensor``;
      * the decoded output starts with the prompt.
    """
    script_to_run = textwrap.dedent(
        """
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_id = "JackFram/llama-68m"

        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
        # Make sure every rank has finished loading before compiling / generating.
        torch.distributed.barrier()

        model.forward = torch.compile(model.forward)

        # Sanity check: the TP plan was applied, i.e. some parameter is a DTensor.
        has_dtensor = any(
            isinstance(parameter.data, torch.distributed.tensor.DTensor)
            for parameter in model.parameters()
        )
        assert has_dtensor, "TP model must have at least one DTensor parameter"

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        prompt = "Can I help"

        inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        # The static cache is the code path under test here.
        outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")

        output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"

        # Keep the ranks in lockstep before tearing the process group down.
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()
        """
    )
    self.torchrun(script_to_run)
|
||||
|
||||
@require_huggingface_hub_greater_or_equal("0.31.4")
|
||||
def test_model_save(self):
|
||||
from safetensors import safe_open
|
||||
|
||||
Reference in New Issue
Block a user