Fix the incorrect permutation of gguf (#31788)
* Fix the incorrect permutation of gguf * rename num_kv_heads Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * add typing to num_kv_heads Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * rename variables * refactor permute function name * update the expected text of the llama3 q4 test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -147,10 +149,11 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
|||||||
|
|
||||||
if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
|
if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
|
||||||
num_heads = parsed_parameters["config"]["num_attention_heads"]
|
num_heads = parsed_parameters["config"]["num_attention_heads"]
|
||||||
tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
|
num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
|
||||||
weights = weights.reshape(tmp_shape)
|
if ".attn_q." in name:
|
||||||
weights = weights.transpose(0, 2, 1, 3)
|
weights = reverse_permute_weights(weights, num_heads, num_heads)
|
||||||
weights = weights.reshape(shape[::-1])
|
elif ".attn_k." in name:
|
||||||
|
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)
|
||||||
|
|
||||||
for tensor_name in tensor_key_mapping:
|
for tensor_name in tensor_key_mapping:
|
||||||
if tensor_name in name:
|
if tensor_name in name:
|
||||||
@@ -163,3 +166,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
|||||||
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
|
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
|
||||||
|
|
||||||
return parsed_parameters
|
return parsed_parameters
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
|
||||||
|
# Original permutation implementation
|
||||||
|
# https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
|
||||||
|
if num_kv_heads is not None and n_head != num_kv_heads:
|
||||||
|
n_head = num_kv_heads
|
||||||
|
|
||||||
|
dim = weights.shape[0] // n_head // 2
|
||||||
|
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
|
||||||
|
return w.swapaxes(2, 1).reshape(weights.shape)
|
||||||
|
|||||||
@@ -188,8 +188,7 @@ class GgufIntegrationTests(unittest.TestCase):
|
|||||||
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
|
||||||
out = model.generate(**text, max_new_tokens=10)
|
out = model.generate(**text, max_new_tokens=10)
|
||||||
|
|
||||||
EXPECTED_TEXT = "Hello, I am new to this forum. I am"
|
EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
|
||||||
|
|
||||||
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||||
|
|
||||||
def test_tokenization_xnli(self):
|
def test_tokenization_xnli(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user