Fix typing order (#39467)

* fix type order

* change all Union[str, dict] to Union[dict, str]

* add hf_parser test && fix test order

* add deepspeed dependency

* replace deepspeed with accelerator
This commit is contained in:
Qizhi Chen
2025-07-17 23:47:31 +08:00
committed by GitHub
parent bda75b4011
commit 73869f2e81
4 changed files with 32 additions and 25 deletions

View File

@@ -1395,7 +1395,7 @@ def _get_torch_dtype(
def _get_device_map(
model: "PreTrainedModel",
device_map: Optional[Union[str, dict]],
device_map: Optional[Union[dict, str]],
max_memory: Optional[dict],
hf_quantizer: Optional[HfQuantizer],
torch_dtype: Optional[torch.dtype],
@@ -2273,7 +2273,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return model
@classmethod
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Union[dict, str]:
"""
Checks that the requested attention implementation exists and tries to get the kernel from hub
if `attn_implementation` matches hf kernels pattern.
@@ -2321,7 +2321,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return attn_implementation
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
"""
Checks and dispatches to the requested attention implementation.
"""