enable static cache on TP model (#39164)

* enable static cache on TP model

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check tp size before init kv cache

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix docstring

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add tp tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix other cache head size

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
jiqing-feng
2025-07-10 05:14:45 +08:00
committed by GitHub
parent 2ef59646b8
commit aff7df8436
4 changed files with 84 additions and 1 deletions

View File

@@ -4494,6 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if tp_size is None:
tp_size = torch.distributed.get_world_size()
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",