diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 47f3f0cf8d..7da09be841 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -217,13 +217,13 @@ def _gguf_parse_value(_value, data_type): return _value -def dequantize_q4_k(data): +def dequantize_q4_k(data, n_bytes: int): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116 block_size = GGML_BLOCK_SIZES["Q4_K"] - num_blocks = len(data) // block_size + num_blocks = n_bytes // block_size data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) @@ -248,13 +248,13 @@ def dequantize_q4_k(data): return factors * qs2 - offsets -def dequantize_q4_0(data): +def dequantize_q4_0(data, n_bytes: int): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11 block_size = GGML_BLOCK_SIZES["Q4_0"] - num_blocks = len(data) // block_size + num_blocks = n_bytes // block_size data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) @@ -274,13 +274,13 @@ def dequantize_q4_0(data): return (scales * quants).astype(np.float32) -def dequantize_q6_k(data): +def dequantize_q6_k(data, n_bytes: int): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152 block_size = GGML_BLOCK_SIZES["Q6_K"] - num_blocks = len(data) // block_size + num_blocks = n_bytes // block_size data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) @@ -327,11 +327,11 @@ def dequantize_q6_k(data): ) -def dequantize_q8_0(data): +def dequantize_q8_0(data, n_bytes: int): # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 block_size = GGML_BLOCK_SIZES["Q8_0"] - num_blocks = len(data) // block_size + num_blocks = n_bytes // block_size scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32) qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] @@ -339,12 +339,12 @@ def dequantize_q8_0(data): return scales * qs -def dequantize_q2_k(data): +def dequantize_q2_k(data, n_bytes: int): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74 - num_blocks = len(data) // GGML_BLOCK_SIZES["Q2_K"] + num_blocks = n_bytes // GGML_BLOCK_SIZES["Q2_K"] data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"] // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"]) @@ -379,12 +379,12 @@ def dequantize_q2_k(data): return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) -def dequantize_q3_k(data): +def dequantize_q3_k(data, n_bytes: int): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95 - num_blocks = len(data) // GGML_BLOCK_SIZES["Q3_K"] + num_blocks = n_bytes // GGML_BLOCK_SIZES["Q3_K"] data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"] // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"]) @@ -428,12 +428,12 @@ def dequantize_q3_k(data): ) -def dequantize_q5_k(data): +def dequantize_q5_k(data, n_bytes: int): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138 - num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_K"] + num_blocks = n_bytes // GGML_BLOCK_SIZES["Q5_K"] data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"] // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"]) @@ -487,25 +487,25 @@ def dequantize_q5_k(data): ) -def load_dequant_gguf_tensor(shape, ggml_type, data): +def load_dequant_gguf_tensor(shape, ggml_type, data, n_bytes): if ggml_type == GGML_TYPES["F32"]: values = data elif ggml_type == GGML_TYPES["F16"]: values = data elif ggml_type == GGML_TYPES["Q8_0"]: - values = dequantize_q8_0(data) + values = dequantize_q8_0(data, n_bytes) elif ggml_type == GGML_TYPES["Q4_0"]: - values = dequantize_q4_0(data) + values = dequantize_q4_0(data, n_bytes) elif ggml_type == GGML_TYPES["Q4_K"]: - values = dequantize_q4_k(data) + values = dequantize_q4_k(data, n_bytes) elif ggml_type == GGML_TYPES["Q6_K"]: - values = dequantize_q6_k(data) + values = dequantize_q6_k(data, n_bytes) elif ggml_type == GGML_TYPES["Q2_K"]: - values = dequantize_q2_k(data) + values = dequantize_q2_k(data, n_bytes) elif ggml_type == GGML_TYPES["Q3_K"]: - values = dequantize_q3_k(data) + values = dequantize_q3_k(data, n_bytes) elif ggml_type == GGML_TYPES["Q5_K"]: - values = dequantize_q5_k(data) + values = dequantize_q5_k(data, n_bytes) else: raise NotImplementedError( f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose" diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 3cf34eab58..0b1621b7bf 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -145,7 +145,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): shape = tensor.shape name = tensor.name - weights = load_dequant_gguf_tensor(shape=shape, ggml_type=tensor.tensor_type, data=tensor.data) + weights = load_dequant_gguf_tensor( + shape=shape, ggml_type=tensor.tensor_type, data=tensor.data, n_bytes=int(tensor.n_bytes) + ) if architecture == "llama" and (".attn_k." in name or ".attn_q." in name): num_heads = parsed_parameters["config"]["num_attention_heads"]