upgrade sentencepiece version (#13564)

This commit is contained in:
elishowk
2021-09-15 15:25:03 +02:00
committed by GitHub
parent e86c02ea90
commit c783e14887
5 changed files with 104 additions and 93 deletions

View File

@@ -44,7 +44,7 @@
"\n",
"from transformers import *\n",
"\n",
"os.chdir('../../')"
"os.chdir(\"../../\")"
]
},
{
@@ -70,15 +70,15 @@
"# Load fine-pruned model and quantize the model\n",
"\n",
"model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
"model.to('cpu')\n",
"model.to(\"cpu\")\n",
"\n",
"quantized_model = torch.quantization.quantize_dynamic(\n",
" model=model,\n",
" qconfig_spec = {\n",
" nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
" },\n",
" dtype=torch.qint8,\n",
" )\n",
" model=model,\n",
" qconfig_spec={\n",
" nn.Linear: torch.quantization.default_dynamic_qconfig,\n",
" },\n",
" dtype=torch.qint8,\n",
")\n",
"# print(quantized_model)\n",
"\n",
"qtz_st = quantized_model.state_dict()"
@@ -92,10 +92,14 @@
"source": [
"# Saving the original (encoder + classifier) in the standard torch.save format\n",
"\n",
"dense_st = {name: param for name, param in model.state_dict().items() \n",
" if \"embedding\" not in name and \"pooler\" not in name}\n",
"torch.save(dense_st, 'dbg/dense_squad.pt',)\n",
"dense_mb_size = os.path.getsize(\"dbg/dense_squad.pt\")\n"
"dense_st = {\n",
" name: param for name, param in model.state_dict().items() if \"embedding\" not in name and \"pooler\" not in name\n",
"}\n",
"torch.save(\n",
" dense_st,\n",
" \"dbg/dense_squad.pt\",\n",
")\n",
"dense_mb_size = os.path.getsize(\"dbg/dense_squad.pt\")"
]
},
{
@@ -198,23 +202,23 @@
" if \"dtype\" not in name and param.is_quantized:\n",
" print(\"Decompose quantization for\", name)\n",
" # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
" scale = param.q_scale() # torch.tensor(1,) - float32\n",
" zero_point = param.q_zero_point() # torch.tensor(1,) - int32\n",
" scale = param.q_scale() # torch.tensor(1,) - float32\n",
" zero_point = param.q_zero_point() # torch.tensor(1,) - int32\n",
" elementary_qtz_st[f\"{name}.scale\"] = scale\n",
" elementary_qtz_st[f\"{name}.zero_point\"] = zero_point\n",
"\n",
" # We assume the int_repr is sparse and compute its CSR representation\n",
" # Only the FCs in the encoder are actually sparse\n",
" int_repr = param.int_repr() # torch.tensor(nb_rows, nb_columns) - int8\n",
" int_repr_cs = sparse.csr_matrix(int_repr) # scipy.sparse.csr.csr_matrix\n",
" int_repr = param.int_repr() # torch.tensor(nb_rows, nb_columns) - int8\n",
" int_repr_cs = sparse.csr_matrix(int_repr) # scipy.sparse.csr.csr_matrix\n",
"\n",
" elementary_qtz_st[f\"{name}.int_repr.data\"] = int_repr_cs.data # np.array int8\n",
" elementary_qtz_st[f\"{name}.int_repr.indptr\"] = int_repr_cs.indptr # np.array int32\n",
" assert max(int_repr_cs.indices) < 65535 # If not, we shall fall back to int32\n",
" elementary_qtz_st[f\"{name}.int_repr.indices\"] = np.uint16(int_repr_cs.indices) # np.array uint16\n",
" elementary_qtz_st[f\"{name}.int_repr.shape\"] = int_repr_cs.shape # tuple(int, int)\n",
" elementary_qtz_st[f\"{name}.int_repr.data\"] = int_repr_cs.data # np.array int8\n",
" elementary_qtz_st[f\"{name}.int_repr.indptr\"] = int_repr_cs.indptr # np.array int32\n",
" assert max(int_repr_cs.indices) < 65535 # If not, we shall fall back to int32\n",
" elementary_qtz_st[f\"{name}.int_repr.indices\"] = np.uint16(int_repr_cs.indices) # np.array uint16\n",
" elementary_qtz_st[f\"{name}.int_repr.shape\"] = int_repr_cs.shape # tuple(int, int)\n",
" else:\n",
" elementary_qtz_st[name] = param\n"
" elementary_qtz_st[name] = param"
]
},
{
@@ -225,7 +229,7 @@
"source": [
"# Create mapping from torch.dtype to string description (we could also used an int8 instead of string)\n",
"str_2_dtype = {\"qint8\": torch.qint8}\n",
"dtype_2_str = {torch.qint8: \"qint8\"}\n"
"dtype_2_str = {torch.qint8: \"qint8\"}"
]
},
{
@@ -246,11 +250,17 @@
"source": [
"# Saving the pruned (encoder + classifier) in the standard torch.save format\n",
"\n",
"dense_optimized_st = {name: param for name, param in elementary_qtz_st.items() \n",
" if \"embedding\" not in name and \"pooler\" not in name}\n",
"torch.save(dense_optimized_st, 'dbg/dense_squad_optimized.pt',)\n",
"print(\"Encoder Size (MB) - Sparse & Quantized - `torch.save`:\",\n",
" round(os.path.getsize(\"dbg/dense_squad_optimized.pt\")/1e6, 2))\n"
"dense_optimized_st = {\n",
" name: param for name, param in elementary_qtz_st.items() if \"embedding\" not in name and \"pooler\" not in name\n",
"}\n",
"torch.save(\n",
" dense_optimized_st,\n",
" \"dbg/dense_squad_optimized.pt\",\n",
")\n",
"print(\n",
" \"Encoder Size (MB) - Sparse & Quantized - `torch.save`:\",\n",
" round(os.path.getsize(\"dbg/dense_squad_optimized.pt\") / 1e6, 2),\n",
")"
]
},
{
@@ -287,7 +297,7 @@
"# Save the decomposed state_dict with an HDF5 file\n",
"# Saving only the encoder + QA Head\n",
"\n",
"with h5py.File('dbg/squad_sparse.h5','w') as hf:\n",
"with h5py.File(\"dbg/squad_sparse.h5\", \"w\") as hf:\n",
" for name, param in elementary_qtz_st.items():\n",
" if \"embedding\" in name:\n",
" print(f\"Skip {name}\")\n",
@@ -318,18 +328,18 @@
" elif type(param) == torch.dtype:\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
" \n",
"\n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
"\n",
"with open('dbg/metadata.json', 'w') as f:\n",
" f.write(json.dumps(qtz_st._metadata)) \n",
"with open(\"dbg/metadata.json\", \"w\") as f:\n",
" f.write(json.dumps(qtz_st._metadata))\n",
"\n",
"size = os.path.getsize(\"dbg/squad_sparse.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
"print(\"\")\n",
"print(\"Encoder Size (MB) - Dense: \", round(dense_mb_size/1e6, 2))\n",
"print(\"Encoder Size (MB) - Sparse & Quantized:\", round(size/1e6, 2))\n"
"print(\"Encoder Size (MB) - Dense: \", round(dense_mb_size / 1e6, 2))\n",
"print(\"Encoder Size (MB) - Sparse & Quantized:\", round(size / 1e6, 2))"
]
},
{
@@ -350,15 +360,15 @@
"# Save the decomposed state_dict to HDF5 storage\n",
"# Save everything in the architecutre (embedding + encoder + QA Head)\n",
"\n",
"with h5py.File('dbg/squad_sparse_with_embs.h5','w') as hf:\n",
"with h5py.File(\"dbg/squad_sparse_with_embs.h5\", \"w\") as hf:\n",
" for name, param in elementary_qtz_st.items():\n",
"# if \"embedding\" in name:\n",
"# print(f\"Skip {name}\")\n",
"# continue\n",
" # if \"embedding\" in name:\n",
" # print(f\"Skip {name}\")\n",
" # continue\n",
"\n",
"# if \"pooler\" in name:\n",
"# print(f\"Skip {name}\")\n",
"# continue\n",
" # if \"pooler\" in name:\n",
" # print(f\"Skip {name}\")\n",
" # continue\n",
"\n",
" if type(param) == torch.Tensor:\n",
" if param.numel() == 1:\n",
@@ -381,17 +391,16 @@
" elif type(param) == torch.dtype:\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
" \n",
"\n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
"\n",
"\n",
"with open('dbg/metadata.json', 'w') as f:\n",
" f.write(json.dumps(qtz_st._metadata)) \n",
"with open(\"dbg/metadata.json\", \"w\") as f:\n",
" f.write(json.dumps(qtz_st._metadata))\n",
"\n",
"size = os.path.getsize(\"dbg/squad_sparse_with_embs.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
"print('\\nSize (MB):', round(size/1e6, 2))\n"
"print(\"\\nSize (MB):\", round(size / 1e6, 2))"
]
},
{
@@ -411,10 +420,10 @@
"\n",
"reconstructed_elementary_qtz_st = {}\n",
"\n",
"hf = h5py.File('dbg/squad_sparse_with_embs.h5','r')\n",
"hf = h5py.File(\"dbg/squad_sparse_with_embs.h5\", \"r\")\n",
"\n",
"for attr_name, attr_param in hf.attrs.items():\n",
" if 'shape' in attr_name:\n",
" if \"shape\" in attr_name:\n",
" attr_param = tuple(attr_param)\n",
" elif \".scale\" in attr_name:\n",
" if \"_packed_params\" in attr_name:\n",
@@ -430,20 +439,20 @@
" attr_param = str_2_dtype[attr_param]\n",
" reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
" # print(f\"Unpack {attr_name}\")\n",
" \n",
"\n",
"# Get the tensors/arrays\n",
"for data_name, data_param in hf.items():\n",
" if \"LayerNorm\" in data_name or \"_packed_params.bias\" in data_name:\n",
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
" elif \"embedding\" in data_name:\n",
" reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
" else: # _packed_params.weight.int_repr.data, _packed_params.weight.int_repr.indices and _packed_params.weight.int_repr.indptr\n",
" else: # _packed_params.weight.int_repr.data, _packed_params.weight.int_repr.indices and _packed_params.weight.int_repr.indptr\n",
" data_param = np.array(data_param)\n",
" if \"indices\" in data_name:\n",
" data_param = np.array(data_param, dtype=np.int32)\n",
" reconstructed_elementary_qtz_st[data_name] = data_param\n",
" # print(f\"Unpack {data_name}\")\n",
" \n",
"\n",
"\n",
"hf.close()"
]
@@ -484,27 +493,29 @@
"for name, param in reconstructed_elementary_qtz_st.items():\n",
" if \"weight.int_repr.indptr\" in name:\n",
" prefix_ = name[:-16]\n",
" data = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.data\"]\n",
" indptr = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indptr\"]\n",
" data = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.data\"]\n",
" indptr = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indptr\"]\n",
" indices = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indices\"]\n",
" shape = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.shape\"]\n",
" shape = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.shape\"]\n",
"\n",
" int_repr = sparse.csr_matrix(arg1=(data, indices, indptr),\n",
" shape=shape)\n",
" int_repr = sparse.csr_matrix(arg1=(data, indices, indptr), shape=shape)\n",
" int_repr = torch.tensor(int_repr.todense())\n",
"\n",
" scale = reconstructed_elementary_qtz_st[f\"{prefix_}.scale\"]\n",
" zero_point = reconstructed_elementary_qtz_st[f\"{prefix_}.zero_point\"]\n",
" weight = torch._make_per_tensor_quantized_tensor(int_repr,\n",
" scale,\n",
" zero_point)\n",
" weight = torch._make_per_tensor_quantized_tensor(int_repr, scale, zero_point)\n",
"\n",
" reconstructed_qtz_st[f\"{prefix_}\"] = weight\n",
" elif \"int_repr.data\" in name or \"int_repr.shape\" in name or \"int_repr.indices\" in name or \\\n",
" \"weight.scale\" in name or \"weight.zero_point\" in name:\n",
" elif (\n",
" \"int_repr.data\" in name\n",
" or \"int_repr.shape\" in name\n",
" or \"int_repr.indices\" in name\n",
" or \"weight.scale\" in name\n",
" or \"weight.zero_point\" in name\n",
" ):\n",
" continue\n",
" else:\n",
" reconstructed_qtz_st[name] = param\n"
" reconstructed_qtz_st[name] = param"
]
},
{
@@ -556,17 +567,17 @@
"source": [
"# Load the re-constructed state dict into a model\n",
"\n",
"dummy_model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')\n",
"dummy_model.to('cpu')\n",
"dummy_model = BertForQuestionAnswering.from_pretrained(\"bert-base-uncased\")\n",
"dummy_model.to(\"cpu\")\n",
"\n",
"reconstructed_qtz_model = torch.quantization.quantize_dynamic(\n",
" model=dummy_model,\n",
" qconfig_spec = None,\n",
" dtype=torch.qint8,\n",
" )\n",
" model=dummy_model,\n",
" qconfig_spec=None,\n",
" dtype=torch.qint8,\n",
")\n",
"\n",
"reconstructed_qtz_st = OrderedDict(reconstructed_qtz_st)\n",
"with open('dbg/metadata.json', 'r') as read_file:\n",
"with open(\"dbg/metadata.json\", \"r\") as read_file:\n",
" metadata = json.loads(read_file.read())\n",
"reconstructed_qtz_st._metadata = metadata\n",
"\n",
@@ -596,8 +607,8 @@
" mask = torch.ones(size=(N, 128))\n",
"\n",
" y_reconstructed = reconstructed_qtz_model(input_ids=inputs, attention_mask=mask)[0]\n",
" y = quantized_model(input_ids=inputs, attention_mask=mask)[0]\n",
" \n",
" y = quantized_model(input_ids=inputs, attention_mask=mask)[0]\n",
"\n",
" assert torch.all(torch.eq(y, y_reconstructed))\n",
"print(\"Sanity check passed\")"
]