enable static cache on TP model (#39164)
* enable static cache on TP model Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * check tp size before init kv cache Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix docstring Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add tp tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix comment Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix other cache head size Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -4494,6 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
if tp_size is None:
|
||||
tp_size = torch.distributed.get_world_size()
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
|
||||
Reference in New Issue
Block a user