Commit affe415f by zhengzifu

更新配置文件,添加对logits的支持,修改权重处理逻辑,注释掉部分运行函数以便于调试。

parent b9aaceb8
......@@ -16,6 +16,7 @@ class CFG:
self.mode = "run" # "test" or "run"
self.run_weights_batch = [
"logits.pkl",
"down.pkl",
"up.pkl",
"gate.pkl",
......@@ -25,6 +26,7 @@ class CFG:
"q.pkl",
]
self.group_numbers_batch = [
5 * 43,
8,
32,
32,
......@@ -34,13 +36,14 @@ class CFG:
32,
]
self.run_weights = "down.pkl" # 用于赋值
self.run_weights = "logits.pkl"
self.group_number = 32
self.safetensors_path = "001-H-LLM/qwen/model.safetensors"
self.safetensors_path = "001-H-LLM/qwen0414/model.safetensors"
self.npz_path = "001-H-LLM/collected_weights_20250408_solveequation.npz"
self.weights_dir = "001-H-LLM/qwen/weights"
self.mapped_weights_dir = "001-H-LLM/qwen/mapped_weights"
self.weights_dir = "001-H-LLM/qwen0414/weights"
self.mapped_weights_dir = "001-H-LLM/qwen0414/mapped_weights"
self.num_workers = 64
self.value_range = [-6, -4, -3, -2, -1.5, -1, -0.5, 0.5, 1, 1.5, 2, 3, 4, 6]
......
......@@ -41,6 +41,12 @@ def run(config: CFG):
os.makedirs(config.weights_dir, exist_ok=True)
print("Start generating quant weights")
file_path = config.safetensors_path
# 处理logits
weights_npz = np.load(config.npz_path)
weight_logit = [weights_npz["Weight:Logit"]]
weights["logits"] = weight_logit
with safetensors.safe_open(file_path, framework="pt") as f:
# 首先收集所有键并按层号排序
all_keys = [(key, get_layer_number(key)) for key in f.keys()]
......@@ -48,6 +54,7 @@ def run(config: CFG):
for key, _ in tqdm(sorted_keys):
# print(key)
# print(f.get_tensor(key).shape)
if not is_substring_in_list(key, name_dict.keys()):
continue
tensor = f.get_tensor(key)
......@@ -73,4 +80,5 @@ def run(config: CFG):
f.write(f"Min value: {np.min(weights[k])}\n")
f.write(f"Mean value: {np.mean(weights[k])}\n")
f.write(f"Standard deviation: {np.std(weights[k])}\n")
print("Quant weights generated at", config.weights_dir)
......@@ -55,7 +55,7 @@ def run_weights_preprocess(config: CFG):
import hllm.eda.generate_quant_weights as generate_quant_weights
import hllm.eda.mapping_weights as generate_mapping_weights
generate_quant_weights.run(config=config)
# generate_quant_weights.run(config=config)
generate_mapping_weights.run(config=config)
......@@ -81,6 +81,6 @@ if __name__ == "__main__":
config = CFG()
# run_weights_preprocess(config)
# run_origin(config)
run_optimized(config)
# run_optimized(config)
# run_verify()
# batch_run(config)
batch_run(config)
- 增加对于 logits 的支持
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment