Commit b9aaceb8 by zhengzifu

完成了对fsm wrapper的修改

parent 08fb0edf
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
GenerateBlock,
ForBlock,
add_parameter,
add_input,
add_output,
add_genvar,
add_assign,
add_wire,
add_body,
add_instance,
add_newline,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
import shutil
from hllm.utils import calculate_WW
# %%
def generate_module(
    module_name,
    H,
    L,
    value_range,
    WW,
    VN,
    group_number,
    weights_file_name,
):
    """Build the FSM wrapper module that instantiates one FSM per vector group.

    Args:
        module_name: Name of the generated Verilog module.
        H: Value emitted for the Verilog ``H`` parameter.
        L: Value emitted for the Verilog ``L`` parameter and the ``LM_sel`` width.
        value_range: Unused here; kept for interface compatibility.
        WW: Unused here; kept for interface compatibility.
        VN: Total number of vectors (emitted as the Verilog ``TTVN`` parameter).
        group_number: Number of vectors handled by one FSM (emitted as the
            Verilog ``VN`` parameter).
        weights_file_name: Weights identifier embedded in the FSM module names.

    Returns:
        The populated ``ModuleBlock``; the caller renders it with ``.generate()``.
    """
    # One FSM instance per group of `group_number` vectors.
    group_count = VN // group_number

    # Every instance forwards the wrapper's own parameters unchanged, so the
    # mapping is built once instead of once per loop iteration.
    fsm_params = {
        "H": "H",
        "L": "L",
        "VN": "VN",
        "WP": "WP",
        "AP": "AP",
        "SCW": "SCW",
        "SCWB": "SCWB",
        "TTW": "TTW",
    }

    with ModuleBlock(module_name) as module:
        # Parameters
        add_parameter("H", H)
        add_parameter("L", L)
        add_parameter("VN", group_number)
        add_parameter("TTVN", VN)
        add_parameter("WP", 5)
        add_parameter("AP", 8)
        add_parameter("SCW", 19)
        add_parameter("SCWB", 5)
        add_parameter("TTW", 28)
        # Inputs and outputs
        add_input("clk")
        add_input("core_valid")
        add_input("rstn")
        add_input("LM_sel", L)
        # NOTE(review): the two numeric arguments of Top_in / WT_result_acc are
        # kept exactly as before; confirm their order (packed width vs. array
        # size) and whether Top_in's 24 should track H against the pyrilog API.
        add_input("Top_in", 24, 8)
        add_output("WT_result_acc", VN, 28)
        add_output("result_valid")
        # Internal wires
        add_wire("result_valid_all", VN)
        for i in range(group_count):
            # Instantiate the FSM for group i.
            fsm_ports = {
                "clk": "clk",
                "valid": "core_valid",
                "fsm_rstn": "rstn",
                "LM_sel": "LM_sel",
                "Top_in": "Top_in",
                # Indexed part-select: base i*VN, constant width VN.
                "WT_result_acc": f"WT_result_acc[{i}*VN+:VN]",
                # Drive the declared wire so the OR-reduction below sees it.
                "result_valid": f"result_valid_all[{i}*VN+:VN]",
            }
            # Module type and instance name carry the group index so that
            # instance names are unique within the wrapper.
            # Assumes the FSM generator emits one module per group named
            # FSM_tp_<weights>_gp_<index> -- TODO confirm.
            add_instance(
                f"FSM_tp_{weights_file_name}_gp_{i}",
                f"inst_FSM_tp_{weights_file_name}_gp_{i}",
                fsm_params,
                fsm_ports,
            )
        # result_valid is the OR of all per-FSM valid bits.
        add_body("\nassign result_valid = |result_valid_all;\n")
    return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, VN, config: CFG):
    """Generate one wrapper module and write it to its ``.sv`` file.

    The file is written to
    ``<config.output_dir>/<name>/<weights_file_name>/<name>_tp_<weights_file_name>_vc_<i>.sv``.

    Returns:
        ``i`` on success so the caller can report progress, or ``None`` when
        any exception occurred (the error is printed, not re-raised).
    """
    try:
        ww = calculate_WW(matrix, config.value_range)

        out_dir = os.path.join(config.output_dir, name, weights_file_name)
        os.makedirs(out_dir, exist_ok=True)
        sv_path = os.path.join(out_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
        wrapper_name = f"{name}_tp_{weights_file_name}_vc_{i}"

        with open(sv_path, "w") as sv_file:
            wrapper = generate_module(
                module_name=wrapper_name,
                H=H,
                L=L,
                value_range=config.value_range,
                WW=ww,
                VN=VN,
                group_number=config.group_number,
                weights_file_name=weights_file_name,
            )
            sv_file.write(wrapper.generate())
    except Exception as e:
        # Worker boundary: report the failure and signal it with None.
        print(f"Generating {i} failed with an error: {e}")
        return None
    else:
        # Hand the task id back so the caller can show progress.
        return i
def run(name: str, config: CFG):
    """Regenerate the wrapper files for the weights file selected in ``config``.

    The output directory ``<config.output_dir>/<name>/<weights file stem>`` is
    removed and recreated, the pickled weights are loaded, and the generation
    tasks are run in a process pool.

    Args:
        name: Sub-directory and module-name prefix for the generated files.
        config: Run configuration; reads ``weights_dir``, ``run_weights``,
            ``output_dir`` and ``num_workers``.
    """
    weights_file = os.path.join(config.weights_dir, config.run_weights)
    weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
    file_dir = os.path.join(config.output_dir, name, weights_file_name)

    # Start from an empty directory so stale files from earlier runs do not linger.
    if os.path.exists(file_dir):
        shutil.rmtree(file_dir, ignore_errors=True)
    os.makedirs(file_dir, exist_ok=True)

    print(f"Processing {weights_file}")
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # config.run_weights at trusted files.
    with open(weights_file, "rb") as f:
        matrixs = pickle.load(f)
    # Swap the first two axes so the leading axis can be unpacked as VN.
    matrixs = np.transpose(matrixs, (1, 0, 2))
    VN, L, H = matrixs.shape

    # NOTE(review): only task 0 is submitted, as before -- confirm this is
    # intended rather than a debugging leftover.
    task_ids = range(1)
    with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
        # Remember which task each future belongs to so a failure can be
        # reported without relying on a result that was never produced.
        future_to_task = {
            executor.submit(
                process_task, i, name, weights_file_name, matrixs[i], H, L, VN, config
            ): i
            for i in task_ids
        }
        for future in tqdm(as_completed(future_to_task), total=len(future_to_task)):
            task_id = future_to_task[future]
            try:
                future.result()
            except Exception as e:
                print(f"Generating {task_id} failed with an error: {e}")
    print("Files generated in", file_dir)
......@@ -35,18 +35,20 @@ def run_optimized(config: CFG):
import hllm.optimized.generate_wrappers as generate_wrappers
import hllm.optimized.generate_layer_mux as generate_layer_mux
import hllm.optimized.generate_fsm as generate_fsm
import hllm.optimized.generate_fsm_wrapper as generate_fsm_wrapper
config.output_dir = "outputs-qwen/optimized"
generate_info.run(name="info", config=config)
generate_mux_wrapper.run(name="Mux_wrapper", config=config)
generate_mux.run(name="Mux", config=config)
generate_sub_wrapper.run(name="Sub_wrapper", config=config)
generate_wt_group.run(name="WT_group", config=config)
generate_mid_wrapper.run(name="Mid_wrapper", config=config)
generate_wallace.run(name="SerialWallaceTree", config=config)
generate_wrappers.run(name="Wrappers", config=config)
generate_layer_mux.run(name="Layer_mux", config=config)
generate_fsm.run(name="FSM", config=config)
# generate_info.run(name="info", config=config)
# generate_mux_wrapper.run(name="Mux_wrapper", config=config)
# generate_mux.run(name="Mux", config=config)
# generate_sub_wrapper.run(name="Sub_wrapper", config=config)
# generate_wt_group.run(name="WT_group", config=config)
# generate_mid_wrapper.run(name="Mid_wrapper", config=config)
# generate_wallace.run(name="SerialWallaceTree", config=config)
# generate_wrappers.run(name="Wrappers", config=config)
# generate_layer_mux.run(name="Layer_mux", config=config)
# generate_fsm.run(name="FSM", config=config)
generate_fsm_wrapper.run(name="FSM_wrapper", config=config)
def run_weights_preprocess(config: CFG):
......@@ -77,8 +79,8 @@ def batch_run(config: CFG):
if __name__ == "__main__":
config = CFG()
run_weights_preprocess(config)
# run_weights_preprocess(config)
# run_origin(config)
# run_optimized(config)
run_optimized(config)
# run_verify()
# batch_run(config)
module FSM_wrapper_tp_k
#(
parameter H = 896, // hidden layer dim, down: 4864 other: 896
parameter L = 24, // layer num
parameter VN = 32, // vector num per group, down: 8 logits: 116 other: 32
module FSM_wrapper_tp_k #(
parameter H = 896, // hidden layer dim, down: 4864 other: 896
parameter L = 24, // layer num
parameter VN = 32, // vector num per group, down: 8 logits: 116 other: 32
parameter TTVN = 128, // total vector num, vary in kqv...
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 19, // related to global max WW, AP+log2(WW_max)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 19, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 28 // MAC output total width, SCW + WP + 4
)
(
input clk,
input core_valid,
input rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
parameter TTW = 28 // MAC output total width, SCW + WP + 4
) (
input clk,
input core_valid,
input rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
// 下面的输出直接传出去,所以不是reg类型
output [TTW - 1 : 0] WT_result_acc[TTVN - 1 : 0],
output result_valid
);
endmodule
\ No newline at end of file
//result valid 是所有FSM的输出valid的或
wire [TTVN - 1 : 0] result_valid_all;
genvar i;
generate
for (i = 0; i < TTVN / VN; i = i + 1) begin : FSM_inst
FSM_tp_k_gp_i #(
.H(H),
.L(L),
.VN(VN),
.WP(WP),
.AP(AP),
.SCW(SCW),
.SCWB(SCWB),
.TTW(TTW)
) FSM_inst_gp_i (
.clk(clk),
.valid(core_valid),
.fsm_rstn(rstn),
.LM_sel(LM_sel),
.Top_in(Top_in),
.WT_result_acc(WT_result_acc[i*VN+:VN]),
.result_valid(result_valid_all[i*VN+:VN])
);
end
endgenerate
assign result_valid = |result_valid_all;
endmodule
- 增加生成前删除原有的文件,目前只改了 optimized 的 fsm
- 增加对于 logits 的支持
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment