Use Python 3.9 syntax in examples (#37279)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-07 19:52:21 +08:00
committed by GitHub
parent 08f36771b3
commit 0fb8d49e88
123 changed files with 358 additions and 451 deletions

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn as nn
@@ -112,7 +112,7 @@ class Int8SymmetricConfig(QuantizationConfigMixin):
Configuration for INT8 symmetric quantization.
"""
def __init__(self, modules_to_not_convert: Optional[List[str]] = None, **kwargs):
def __init__(self, modules_to_not_convert: Optional[list[str]] = None, **kwargs):
self.quant_method = "int8_symmetric"
self.modules_to_not_convert = modules_to_not_convert
@@ -120,7 +120,7 @@ class Int8SymmetricConfig(QuantizationConfigMixin):
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
def to_diff_dict(self) -> Dict[str, Any]:
def to_diff_dict(self) -> dict[str, Any]:
config_dict = self.to_dict()
default_config_dict = Int8SymmetricConfig().to_dict()
@@ -164,7 +164,7 @@ class Int8SymmetricQuantizer(HfQuantizer):
model,
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
**kwargs,
):
module, tensor_name = get_module_from_name(model, param_name)
@@ -186,8 +186,8 @@ class Int8SymmetricQuantizer(HfQuantizer):
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
state_dict: dict[str, Any],
unexpected_keys: Optional[list[str]] = None,
):
"""
Quantizes weights to INT8 symmetric format.
@@ -202,7 +202,7 @@ class Int8SymmetricQuantizer(HfQuantizer):
module._buffers[tensor_name] = weight_quantized.to(target_device)
module._buffers["weight_scale"] = weight_scale.to(target_device)
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, Int8SymmetricLinear):