Use Python 3.9 syntax in examples (#37279)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -112,7 +112,7 @@ class Int8SymmetricConfig(QuantizationConfigMixin):
|
||||
Configuration for INT8 symmetric quantization.
|
||||
"""
|
||||
|
||||
def __init__(self, modules_to_not_convert: Optional[List[str]] = None, **kwargs):
|
||||
def __init__(self, modules_to_not_convert: Optional[list[str]] = None, **kwargs):
|
||||
self.quant_method = "int8_symmetric"
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@@ -120,7 +120,7 @@ class Int8SymmetricConfig(QuantizationConfigMixin):
|
||||
config_dict = self.to_dict()
|
||||
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||
|
||||
def to_diff_dict(self) -> Dict[str, Any]:
|
||||
def to_diff_dict(self) -> dict[str, Any]:
|
||||
config_dict = self.to_dict()
|
||||
default_config_dict = Int8SymmetricConfig().to_dict()
|
||||
|
||||
@@ -164,7 +164,7 @@ class Int8SymmetricQuantizer(HfQuantizer):
|
||||
model,
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
):
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
@@ -186,8 +186,8 @@ class Int8SymmetricQuantizer(HfQuantizer):
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: Dict[str, Any],
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
state_dict: dict[str, Any],
|
||||
unexpected_keys: Optional[list[str]] = None,
|
||||
):
|
||||
"""
|
||||
Quantizes weights to INT8 symmetric format.
|
||||
@@ -202,7 +202,7 @@ class Int8SymmetricQuantizer(HfQuantizer):
|
||||
module._buffers[tensor_name] = weight_quantized.to(target_device)
|
||||
module._buffers["weight_scale"] = weight_scale.to(target_device)
|
||||
|
||||
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
|
||||
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
|
||||
not_missing_keys = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, Int8SymmetricLinear):
|
||||
|
||||
Reference in New Issue
Block a user