Fix typing order (#39467)
* fix type order * change all Union[str, dict] to Union[dict, str] * add hf_parser test && fix test order * add deepspeed dependency * replace deepspeed with accelerator
This commit is contained in:
@@ -1395,7 +1395,7 @@ def _get_torch_dtype(
|
||||
|
||||
def _get_device_map(
|
||||
model: "PreTrainedModel",
|
||||
device_map: Optional[Union[str, dict]],
|
||||
device_map: Optional[Union[dict, str]],
|
||||
max_memory: Optional[dict],
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
@@ -2273,7 +2273,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
|
||||
def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Union[dict, str]:
|
||||
"""
|
||||
Checks that the requested attention implementation exists and tries to get the kernel from hub
|
||||
if `attn_implementation` matches hf kernels pattern.
|
||||
@@ -2321,7 +2321,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
return attn_implementation
|
||||
|
||||
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
|
||||
def set_attention_implementation(self, attn_implementation: Union[dict, str]):
|
||||
"""
|
||||
Checks and dispatches to the requested attention implementation.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user