[Style] fix E721 warnings (#36474)

* fix E721 warnings

* config.hidden_size is not a tuple

* fix copies

* fix-copies

* not a tuple

* undo

* undo
This commit is contained in:
Kashif Rasul
2025-03-03 19:03:42 +01:00
committed by GitHub
parent 1975be4d97
commit 9fe82793ee
25 changed files with 37 additions and 37 deletions

View File

@@ -41,7 +41,7 @@
"from scipy import sparse\n",
"from torch import nn\n",
"\n",
"from transformers import *\n",
"from transformers import BertForQuestionAnswering\n",
"\n",
"\n",
"os.chdir(\"../../\")"
@@ -307,7 +307,7 @@
" print(f\"Skip {name}\")\n",
" continue\n",
"\n",
" if type(param) == torch.Tensor:\n",
" if isinstance(param, torch.Tensor):\n",
" if param.numel() == 1:\n",
" # module scale\n",
" # module zero_point\n",
@@ -319,13 +319,13 @@
" param = param.detach().numpy()\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
" elif isinstance(param, (float, int, tuple)):\n",
" # float - tensor _packed_params.weight.scale\n",
" # int - tensor _packed_params.weight.zero_point\n",
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
" elif type(param) == torch.dtype:\n",
" elif isinstance(param, torch.dtype):\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
"\n",
@@ -370,7 +370,7 @@
" # print(f\"Skip {name}\")\n",
" # continue\n",
"\n",
" if type(param) == torch.Tensor:\n",
" if isinstance(param, torch.Tensor):\n",
" if param.numel() == 1:\n",
" # module scale\n",
" # module zero_point\n",
@@ -382,13 +382,13 @@
" param = param.detach().numpy()\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
" elif isinstance(param, (float, int, tuple)):\n",
" # float - tensor _packed_params.weight.scale\n",
" # int - tensor _packed_params.weight.zero_point\n",
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
" elif type(param) == torch.dtype:\n",
" elif isinstance(param, torch.dtype):\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
"\n",
@@ -471,10 +471,10 @@
" assert name in reconstructed_elementary_qtz_st, name\n",
"\n",
"for name, param in reconstructed_elementary_qtz_st.items():\n",
" assert type(param) == type(elementary_qtz_st[name]), name\n",
" if type(param) == torch.Tensor:\n",
" assert isinstance(param, type(elementary_qtz_st[name])), name\n",
" if isinstance(param, torch.Tensor):\n",
" assert torch.all(torch.eq(param, elementary_qtz_st[name])), name\n",
" elif type(param) == np.ndarray:\n",
" elif isinstance(param, np.ndarray):\n",
" assert (param == elementary_qtz_st[name]).all(), name\n",
" else:\n",
" assert param == elementary_qtz_st[name], name"
@@ -532,10 +532,10 @@
" assert name in reconstructed_qtz_st, name\n",
"\n",
"for name, param in reconstructed_qtz_st.items():\n",
" assert type(param) == type(qtz_st[name]), name\n",
" if type(param) == torch.Tensor:\n",
" assert isinstance(param, type(qtz_st[name])), name\n",
" if isinstance(param, torch.Tensor):\n",
" assert torch.all(torch.eq(param, qtz_st[name])), name\n",
" elif type(param) == np.ndarray:\n",
" elif isinstance(param, np.ndarray):\n",
" assert (param == qtz_st[name]).all(), name\n",
" else:\n",
" assert param == qtz_st[name], name"