Commit 8c446544 by zhengzifu

更新主程序以引入文件列表生成函数,修改配置文件以调整权重和组数,优化FSM模块的参数设置,添加新的FSM生成脚本。

parent affe415f
......@@ -36,8 +36,8 @@ class CFG:
32,
]
self.run_weights = "logits.pkl"
self.group_number = 32
self.run_weights = "down.pkl"
self.group_number = 8
self.safetensors_path = "001-H-LLM/qwen0414/model.safetensors"
self.npz_path = "001-H-LLM/collected_weights_20250408_solveequation.npz"
......
......@@ -34,12 +34,12 @@ def generate_module(
module_name_suffix = f"_tp_{weights_file_name}_gp_{cur_GP}"
str = f"""
module FSM_tp_{weights_file_name}_gp_{cur_GP} #(
parameter H = 896, // hidden layer dim, 896
parameter L = 24, // layer num, 24
parameter VN = 32, // vector num, 896 or 3840(in FFN)
parameter H = {H}, // hidden layer dim, 896
parameter L = {L}, // layer num, 24
parameter VN = {config.group_number}, // vector num, 896 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 18, // related to global max WW, AP+log2(WW_max)
parameter SCW = 19, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 32 // MAC output total width, SCW + WP + 4
) (
......
......@@ -48,20 +48,20 @@ def generate_module(
add_parameter("AP", 8)
add_parameter("SCW", 19)
add_parameter("SCWB", 5)
add_parameter("TTW", 28)
add_parameter("TTW", 32)
# 输入输出
add_input("clk")
add_input("core_valid")
add_input("rstn")
add_input("LM_sel", 24)
add_input("Top_in", 24, 8)
add_output("WT_result_acc", VN, 28)
add_input("LM_sel", "L")
add_input("Top_in", "AP", "H")
add_output("WT_result_acc", "TTW", "VN")
add_output("result_valid")
# 内部连线
add_wire("result_valid_all", VN)
for i in range(VN):
for i in range(VN // group_number):
# 生成FSM
fsm_params = {
"H": "H",
......@@ -80,17 +80,17 @@ def generate_module(
"LM_sel": "LM_sel",
"Top_in": "Top_in",
"WT_result_acc": f"WT_result_acc[{i}*VN+:{i + 1}*VN]",
"result_valid": f"result_valid_all_{i}",
"result_valid": f"result_valid_all[{i}]",
}
add_instance(
f"FSM_tp_{weights_file_name}_gp_{GP}",
f"inst_FSM_tp_{weights_file_name}_gp_{GP}",
f"FSM_tp_{weights_file_name}_gp_{i}",
f"inst_FSM_tp_{weights_file_name}_gp_{i}",
fsm_params,
fsm_ports,
)
add_body(
f"""
assign result_valid = |result_valid_all;
assign result_valid = &result_valid_all;
"""
)
return module
......
......@@ -39,7 +39,7 @@ module FSM_tp_{weights_file_name}_gp_{cur_GP} #(
parameter VN = 32, // vector num, 896 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 18, // related to global max WW, AP+log2(WW_max)
parameter SCW = 19, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 32 // MAC output total width, SCW + WP + 4
) (
......
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
GenerateBlock,
ForBlock,
add_parameter,
add_input,
add_output,
add_genvar,
add_assign,
add_wire,
add_body,
add_instance,
add_newline,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
import shutil
from hllm.utils import calculate_WW
# %%
def generate_module(
module_name,
H,
L,
value_range,
WW,
VN,
group_number,
weights_file_name,
):
L_width = int(np.ceil(np.log2(L)))
GP = int(VN / group_number)
with ModuleBlock(module_name) as module:
# 参数
add_parameter("H", 896)
add_parameter("L", 24)
add_parameter("VN", group_number)
add_parameter("TTVN", VN)
add_parameter("WP", 5)
add_parameter("AP", 8)
add_parameter("SCW", 19)
add_parameter("SCWB", 5)
add_parameter("TTW", 32)
# 输入输出
add_input("clk")
add_input("core_valid")
add_input("rstn")
add_input("LM_sel", "L")
add_input("Top_in", "AP", "H")
add_output("WT_result_acc", "TTW", "VN")
add_output("result_valid")
# 内部连线
add_wire("result_valid_all", VN)
for i in range(VN):
# 生成FSM
fsm_params = {
"H": "H",
"L": "L",
"VN": "VN",
"WP": "WP",
"AP": "AP",
"SCW": "SCW",
"SCWB": "SCWB",
"TTW": "TTW",
}
fsm_ports = {
"clk": "clk",
"valid": "core_valid",
"fsm_rstn": "rstn",
"LM_sel": "LM_sel",
"Top_in": "Top_in",
"WT_result_acc": f"WT_result_acc[{i}*VN+:{i + 1}*VN]",
"result_valid": f"result_valid_all[{i}]",
}
add_instance(
f"FSM_tp_{weights_file_name}_gp_{GP}",
f"inst_FSM_tp_{weights_file_name}_gp_{GP}",
fsm_params,
fsm_ports,
)
add_body(
f"""
assign result_valid = &result_valid_all;
"""
)
return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, VN, config: CFG):
try:
WW = calculate_WW(matrix, config.value_range)
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_vc_{i}"
with open(file_name, "w") as f:
f.write(
generate_module(
module_name=module_name,
H=H,
L=L,
value_range=config.value_range,
WW=WW,
VN=VN,
group_number=config.group_number,
weights_file_name=weights_file_name,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, VN, config
)
for i in range(1)
]
for future in tqdm(as_completed(futures), total=1):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
print("Files generated in", file_dir)
......@@ -24,3 +24,26 @@ def find_index(arr, target, epsilon=1e-3):
if min_diff < epsilon: # 如果最小差值在允许的误差范围内
return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引
raise ValueError("No match found") # 如果没有找到匹配项,则引发异常
def generate_filelist(config: CFG):
print(f"Generating filelist for {config.output_dir}")
filelist = []
def _scan_dir(dir_path):
files = []
for item in os.listdir(dir_path):
item_path = os.path.join(dir_path, item)
if os.path.isfile(item_path) and item.endswith(".sv"):
files.append(os.path.abspath(item_path))
elif os.path.isdir(item_path):
files.extend(_scan_dir(item_path))
return files
filelist = _scan_dir(config.output_dir)
with open(os.path.join(config.output_dir, "filelist.tcl"), "w") as f:
for file in filelist:
f.write(f"{file}\n")
print(f"Filelist generated for {config.output_dir}")
return filelist
from hllm.config import CFG
from hllm.utils import generate_filelist
def run_origin(config: CFG):
......@@ -22,6 +23,7 @@ def run_origin(config: CFG):
generate_wallace.run(name="SerialWallaceTree", config=config)
generate_wrappers.run(name="Wrappers", config=config)
generate_wt_group.run(name="WT_group", config=config)
generate_filelist(config)
def run_optimized(config: CFG):
......@@ -38,16 +40,16 @@ def run_optimized(config: CFG):
import hllm.optimized.generate_fsm_wrapper as generate_fsm_wrapper
config.output_dir = "outputs-qwen/optimized"
# generate_info.run(name="info", config=config)
# generate_mux_wrapper.run(name="Mux_wrapper", config=config)
# generate_mux.run(name="Mux", config=config)
# generate_sub_wrapper.run(name="Sub_wrapper", config=config)
# generate_wt_group.run(name="WT_group", config=config)
# generate_mid_wrapper.run(name="Mid_wrapper", config=config)
# generate_wallace.run(name="SerialWallaceTree", config=config)
# generate_wrappers.run(name="Wrappers", config=config)
# generate_layer_mux.run(name="Layer_mux", config=config)
# generate_fsm.run(name="FSM", config=config)
generate_info.run(name="info", config=config)
generate_mux_wrapper.run(name="Mux_wrapper", config=config)
generate_mux.run(name="Mux", config=config)
generate_sub_wrapper.run(name="Sub_wrapper", config=config)
generate_wt_group.run(name="WT_group", config=config)
generate_mid_wrapper.run(name="Mid_wrapper", config=config)
generate_wallace.run(name="SerialWallaceTree", config=config)
generate_wrappers.run(name="Wrappers", config=config)
generate_layer_mux.run(name="Layer_mux", config=config)
generate_fsm.run(name="FSM", config=config)
generate_fsm_wrapper.run(name="FSM_wrapper", config=config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment