Commit 8edf8925 by zhengzifu

First Commit

parents
__pycache__
.DS_Store
.venv
.vscode
outputs
outputs-qwen
weights
model.safetensors
001-H-LLM
build
src-Optimize_HN
*.egg-info
*.egg
*.so
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from config import CFG
from tqdm import tqdm
import pickle
import json
path_dir = "Optimized_HN"
def calculate_WW(matrix: np.array, value_range):
WW = [0] * len(value_range)
for i in range(len(value_range)):
WW[i] = max(
[len([x for x in row if abs(x - value_range[i]) <= 0.01]) for row in matrix]
)
return WW
def find_index(arr, target, epsilon=1e-3):
arr = np.array(arr) # 转换为numpy数组
diff = np.abs(arr - target) # 计算差值数组
min_diff = np.min(diff) # 找到最小的差值
if min_diff < epsilon: # 如果最小差值在允许的误差范围内
return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引
raise ValueError("No match found") # 如果没有找到匹配项,则引发异常
# %%
def generate_module(
matrix,
module_name="HN",
H=16,
L=5,
value_range=[-1, 1],
WW=[8, 8],
CUR_VN=0,
weights_file_name="",
):
# with VerilogGenerator() as generator:
with ModuleBlock(f"{module_name}") as module:
add_parameter("H", H)
add_parameter("L", L)
# for i in range(len(value_range)):
# add_parameter(f"WW_{i}", WW[i])
add_input("HN_in", "H")
add_input("CST_LOW")
node_file = os.path.join(
"output", weights_file_name, f"{weights_file_name}_vn_{CUR_VN}.json"
)
node = json.load(open(node_file))["node"]
# max_L=0
# 内部连线
for i, hn_in_layers in enumerate(node):
color_file = os.path.join(
"output",
weights_file_name,
f"{weights_file_name}_vn_{CUR_VN}_value_{value_range[i]}.json",
)
mux_port = json.load(open(color_file))["color"]
max_mux_port = max(mux_port) + 1
add_parameter(f"WW_{i}", max_mux_port)
max_L = max(mux_port.count(x) for x in set(mux_port) if x != -1)
add_output(
f"HN_out_{i}",
f"WW_{i}",
f"{max_L}",
)
# 第一维是该颜色的使用次数 第二维是染的颜色即mux_port
hn_out = [[-1 for _ in range(max_mux_port)] for _ in range(max_L)]
used_mux_port = [0 for _ in range(max_mux_port)]
for j, hn_in_layer in enumerate(hn_in_layers):
if mux_port[j] == -1:
continue
hn_out[used_mux_port[mux_port[j]]][mux_port[j]] = j
used_mux_port[mux_port[j]] += 1
for j in range(max_L):
for k in range(max_mux_port):
if hn_out[j][k] == -1:
add_assign(f"HN_out_{i}", [j, k], "CST_LOW", [])
else:
add_assign(f"HN_out_{i}", [j, k], "HN_in", [hn_out[j][k]])
return module
# %%
def process_task(i, weights_file_name, matrix, H, L):
try:
WW = calculate_WW(matrix, CFG.value_range)
file_dir = os.path.join(path_dir, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(
file_dir, f"{path_dir}_tp_{weights_file_name}_vc_{i}.sv"
)
with open(file_name, "w") as f:
text = generate_module(
matrix,
module_name=f"{path_dir}_tp_{weights_file_name}_vc_{i}",
H=H,
L=L,
value_range=CFG.value_range,
WW=WW,
CUR_VN=i,
weights_file_name=weights_file_name,
).generate()
f.write(text)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run():
os.makedirs(path_dir, exist_ok=True)
for weights_file in os.listdir("mapped_weights"):
if weights_file != CFG.run_weights:
continue
weights_path = os.path.join("mapped_weights", weights_file)
weights_file_name = os.path.splitext(weights_file)[0]
print(f"Processing {weights_file_name}")
with open(weights_path, "rb") as f:
print(f"Loading {weights_file_name}")
matrixs = pickle.load(f)
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=CFG.num_workers) as executor:
futures = [
executor.submit(process_task, i, weights_file_name, matrixs[i], H, L)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
if __name__ == "__main__":
run()
# %%
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from tqdm import tqdm
import pickle
from hllm.config import CFG
# %%
def generate_module(
cur_GP=0,
module_name="",
H=16,
L=5,
VN=512,
value_range=[-1, 1],
weights_file_name=None,
WW=[8, 8],
config: CFG = None,
):
module_name_suffix = f"_tp_{weights_file_name}_gp_{cur_GP}"
str = f"""
module FSM_tp_{weights_file_name}_gp_{cur_GP} #(
parameter H = 1536, // hidden layer dim, 1536
parameter L = 52, // layer num, 52
parameter VN = 32, // vector num, 1536 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 17, // DG, related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // DG, log2(SCW)
parameter TTW = 26 // DG, MAC output total width, SCW + WP + 4
) (
input clk,
input valid,
input fsm_rstn,
input [6 - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
output reg [TTW - 1 : 0] WT_result_acc[VN - 1 : 0],
output reg result_valid
);"""
str += r"""
reg [AP - 1 : 0] TM_sel;
wire [ H - 1 : 0] TM_out;
Top_mux #(
.H (H),
.AP(AP)
) top_mux (
.TM_sel(TM_sel),
.TM_in (Top_in),
.TM_out(TM_out)
);
wire WT_v0_out_S[VN - 1 : 0];
wire WT_v0_out_C[VN - 1 : 0];
wire WT_v1_out_S[VN - 1 : 0];
wire WT_v1_out_C[VN - 1 : 0];
wire WT_v2_out_S[VN - 1 : 0];
wire WT_v2_out_C[VN - 1 : 0];
wire WT_v3_out_S[VN - 1 : 0];
wire WT_v3_out_C[VN - 1 : 0];
wire WT_v4_out_S[VN - 1 : 0];
wire WT_v4_out_C[VN - 1 : 0];
wire WT_v5_out_S[VN - 1 : 0];
wire WT_v5_out_C[VN - 1 : 0];
wire WT_v6_out_S[VN - 1 : 0];
wire WT_v6_out_C[VN - 1 : 0];
wire WT_v7_out_S[VN - 1 : 0];
wire WT_v7_out_C[VN - 1 : 0];
wire WT_v8_out_S[VN - 1 : 0];
wire WT_v8_out_C[VN - 1 : 0];
wire WT_v9_out_S[VN - 1 : 0];
wire WT_v9_out_C[VN - 1 : 0];
wire WT_v10_out_S[VN - 1 : 0];
wire WT_v10_out_C[VN - 1 : 0];
wire WT_v11_out_S[VN - 1 : 0];
wire WT_v11_out_C[VN - 1 : 0];
wire WT_v12_out_S[VN - 1 : 0];
wire WT_v12_out_C[VN - 1 : 0];
wire WT_v13_out_S[VN - 1 : 0];
wire WT_v13_out_C[VN - 1 : 0];
reg tree_rstn;
reg [SCW - 1 : 0] final_S_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v13[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v13[VN - 1 : 0];
wire [TTW - 1 : 0] MAC_out[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v0[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v1[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v2[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v3[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v4[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v5[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v6[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v7[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v8[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v9[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v10[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v11[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v12[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v13[VN - 1 : 0];
reg [SCWB : 0] idx;
reg [2 : 0] state;
reg CST_LOW;
genvar j;
integer i;"""
str += f"""
Mid_wrapper_tp_{weights_file_name}_gp_{cur_GP} #(
.H (H),
.L (L),
.VN(VN)
) mid_wrappers (
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(TM_out),
.WT_0_out_S(WT_v0_out_S),
.WT_0_out_C(WT_v0_out_C),
.WT_1_out_S(WT_v1_out_S),
.WT_1_out_C(WT_v1_out_C),
.WT_2_out_S(WT_v2_out_S),
.WT_2_out_C(WT_v2_out_C),
.WT_3_out_S(WT_v3_out_S),
.WT_3_out_C(WT_v3_out_C),
.WT_4_out_S(WT_v4_out_S),
.WT_4_out_C(WT_v4_out_C),
.WT_5_out_S(WT_v5_out_S),
.WT_5_out_C(WT_v5_out_C),
.WT_6_out_S(WT_v6_out_S),
.WT_6_out_C(WT_v6_out_C),
.WT_7_out_S(WT_v7_out_S),
.WT_7_out_C(WT_v7_out_C),
.WT_8_out_S(WT_v8_out_S),
.WT_8_out_C(WT_v8_out_C),
.WT_9_out_S(WT_v9_out_S),
.WT_9_out_C(WT_v9_out_C),
.WT_10_out_S(WT_v10_out_S),
.WT_10_out_C(WT_v10_out_C),
.WT_11_out_S(WT_v11_out_S),
.WT_11_out_C(WT_v11_out_C),
.WT_12_out_S(WT_v12_out_S),
.WT_12_out_C(WT_v12_out_C),
.WT_13_out_S(WT_v13_out_S),
.WT_13_out_C(WT_v13_out_C)
);
"""
str += r"""
generate
for (j = 0; j < VN; j = j + 1) begin : inst_SW_loop
MAC #(
.W_1(SCW), // input 1 width
.W_2(WP), // input 2 width
.W_O(TTW), // output width
.NUM(16) // parallel width
) mac (
.clk(clk),
.rstn(fsm_rstn),
.MAC_in_1({
{SCW{1'b0}},
{SCW{1'b0}},
WT_result_v13[j],
WT_result_v12[j],
WT_result_v11[j],
WT_result_v10[j],
WT_result_v9[j],
WT_result_v8[j],
WT_result_v7[j],
WT_result_v6[j],
WT_result_v5[j],
WT_result_v4[j],
WT_result_v3[j],
WT_result_v2[j],
WT_result_v1[j],
WT_result_v0[j]
}),
//.MAC_in_2({weight_0, weight_1, -5'd6, -5'd4, -5'd3, -5'd2, -5'd1, 5'd0, 5'd1, 5'd2, 5'd3, 5'd4, 5'd6, 5'd8, 5'd12, 5'd0}),
.MAC_out(MAC_out[j])
);
end
endgenerate
always @(posedge clk or negedge fsm_rstn) begin
if (!fsm_rstn) begin
state <= 0;
idx <= 0;
tree_rstn <= 0;
result_valid <= 0;
CST_LOW <= 0;
TM_sel <= 8'b00000000;
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= 0;
end
end else begin
if (state == 0) begin
idx <= 0;
tree_rstn <= 0;
result_valid <= 0;
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
if (valid == 1) begin
state <= 1;
end
end
if (state == 1) begin
tree_rstn <= 1;
if (idx == 0) begin
TM_sel <= 8'b00000001;
end else begin
if (idx == 1) begin
TM_sel <= 8'b00000010;
end
if (idx == 2) begin
TM_sel <= 8'b00000100;
end
if (idx == 3) begin
TM_sel <= 8'b00001000;
end
if (idx == 4) begin
TM_sel <= 8'b00010000;
end
if (idx == 5) begin
TM_sel <= 8'b00100000;
end
if (idx == 6) begin
TM_sel <= 8'b01000000;
end
if (idx == 7) begin
TM_sel <= 8'b10000000;
end
if (idx > 7) begin
TM_sel <= 8'b10000000;
end // signed extension
if (idx == SCW) begin
state <= 2;
end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i][idx-1] <= WT_v0_out_S[i];
final_C_v0[i][idx-1] <= WT_v0_out_C[i];
final_S_v1[i][idx-1] <= WT_v1_out_S[i];
final_C_v1[i][idx-1] <= WT_v1_out_C[i];
final_S_v2[i][idx-1] <= WT_v2_out_S[i];
final_C_v2[i][idx-1] <= WT_v2_out_C[i];
final_S_v3[i][idx-1] <= WT_v3_out_S[i];
final_C_v3[i][idx-1] <= WT_v3_out_C[i];
final_S_v4[i][idx-1] <= WT_v4_out_S[i];
final_C_v4[i][idx-1] <= WT_v4_out_C[i];
final_S_v5[i][idx-1] <= WT_v5_out_S[i];
final_C_v5[i][idx-1] <= WT_v5_out_C[i];
final_S_v6[i][idx-1] <= WT_v6_out_S[i];
final_C_v6[i][idx-1] <= WT_v6_out_C[i];
final_S_v7[i][idx-1] <= WT_v7_out_S[i];
final_C_v7[i][idx-1] <= WT_v7_out_C[i];
final_S_v8[i][idx-1] <= WT_v8_out_S[i];
final_C_v8[i][idx-1] <= WT_v8_out_C[i];
final_S_v9[i][idx-1] <= WT_v9_out_S[i];
final_C_v9[i][idx-1] <= WT_v9_out_C[i];
final_S_v10[i][idx-1] <= WT_v10_out_S[i];
final_C_v10[i][idx-1] <= WT_v10_out_C[i];
final_S_v11[i][idx-1] <= WT_v11_out_S[i];
final_C_v11[i][idx-1] <= WT_v11_out_C[i];
final_S_v12[i][idx-1] <= WT_v12_out_S[i];
final_C_v12[i][idx-1] <= WT_v12_out_C[i];
final_S_v13[i][idx-1] <= WT_v13_out_S[i];
final_C_v13[i][idx-1] <= WT_v13_out_C[i];
end
end
idx <= idx + 1;
end
if (state == 2) begin
for (
i = 0; i < VN; i = i + 1
) begin
WT_result_v0[i] <= {1'b0, final_S_v0[i]} + {final_C_v0[i], 1'b0};
WT_result_v1[i] <= {1'b0, final_S_v1[i]} + {final_C_v1[i], 1'b0};
WT_result_v2[i] <= {1'b0, final_S_v2[i]} + {final_C_v2[i], 1'b0};
WT_result_v3[i] <= {1'b0, final_S_v3[i]} + {final_C_v3[i], 1'b0};
WT_result_v4[i] <= {1'b0, final_S_v4[i]} + {final_C_v4[i], 1'b0};
WT_result_v5[i] <= {1'b0, final_S_v5[i]} + {final_C_v5[i], 1'b0};
WT_result_v6[i] <= {1'b0, final_S_v6[i]} + {final_C_v6[i], 1'b0};
WT_result_v7[i] <= {1'b0, final_S_v7[i]} + {final_C_v7[i], 1'b0};
WT_result_v8[i] <= {1'b0, final_S_v8[i]} + {final_C_v8[i], 1'b0};
WT_result_v9[i] <= {1'b0, final_S_v9[i]} + {final_C_v9[i], 1'b0};
WT_result_v10[i] <= {1'b0, final_S_v10[i]} + {final_C_v10[i], 1'b0};
WT_result_v11[i] <= {1'b0, final_S_v11[i]} + {final_C_v11[i], 1'b0};
WT_result_v12[i] <= {1'b0, final_S_v12[i]} + {final_C_v12[i], 1'b0};
WT_result_v13[i] <= {1'b0, final_S_v13[i]} + {final_C_v13[i], 1'b0};
end
idx <= 0;
state <= 3;
tree_rstn <= 0;
end
if (state == 3) begin
// MAC adder
state <= 4;
end
if (state == 4) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= MAC_out[i];
end
state <= 5;
result_valid <= 1;
end
if (state == 5) begin
// 输出及其他握手信号,待用
idx <= 0;
state <= 0;
tree_rstn <= 0;
end
end
end
endmodule
"""
return str
# %%
def process_task(i, name, weights_file_name, matrix, H, L, VN, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_gp_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_gp_{i}"
module = generate_module(
i,
module_name=module_name,
H=H,
L=L,
VN=VN,
value_range=config.value_range,
weights_file_name=weights_file_name,
config=config,
)
with open(file_name, "w") as f:
f.write(module)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
GP = int(VN / config.group_number)
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, VN, config
)
for i in range(GP)
]
for future in tqdm(as_completed(futures), total=GP):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
import numpy as np
import pickle
import os
import sys
from config import CFG
if CFG.mode == "run":
exit
def run():
shape = CFG.test_weights_shape
weights = np.random.choice(CFG.value_range, shape)
filename_pkl = os.path.join(CFG.weights_dir, CFG.test_weights)
filename_txt = os.path.join(
CFG.weights_dir, CFG.test_weights.split(".")[0] + ".txt"
)
with open(filename_pkl, "wb") as f:
pickle.dump(weights, f)
with open(filename_txt, "w") as f:
for row in weights:
for val in row:
f.write(f"{val}\n")
f.write("\n\n")
CXX = g++
CXXFLAGS = -std=c++11 -Wall -pthread
TARGET = optimize_HN
SRC = optimize_HN.cpp
$(TARGET): $(SRC)
$(CXX) $(CXXFLAGS) -o $(TARGET) $(SRC)
.PHONY: clean
clean:
rm -f $(TARGET)
#include <iostream>
#include <vector>
#include <random>
#include <algorithm>
#include <ctime>
#include <thread>
#include <atomic>
#include <signal.h>
#include <fstream>
#include <sstream>
#include <climits>
#include <mutex>
using namespace std;
// 遗传算法超参数
int DEFAULT_POPULATION_SIZE = 500; // 默认种群大小
double DEFAULT_MUTATION_RATE = 0.1; // 默认变异率
int DEFAULT_GENERATIONS = 500; // 默认进化代数
int MIN_COLORS = 2; // 最小颜色数
int MAX_COLORS = 1000; // 最大颜色数
int NUM_WORKERS = 64; // 并行工作线程数
int TIMEOUT_SECONDS = 500; // 超时时间(秒)
string GRAPH_PATH= "0_0"; // 图的索引
// 图生成参数
int VERTICES = 50; // 图的顶点数
double EDGE_PROBABILITY = 0.1; // 边的生成概率
// 定义全局变量来存储找到的有效解和颜色数
vector<int> globalSolution;
int globalNumColors = INT_MAX;
// 定义互斥锁来保护全局变量
mutex globalMutex;
// 解析命令行参数的函数
void parseCommandLineArguments(int argc, char *argv[])
{
if (argc == 1)
{
return;
}
if (argc == 2)
{
GRAPH_PATH = argv[1];
return;
}
if (argc < 9)
{
cerr << "用法: " << argv[0] << " <GRAPH_PATH> <DEFAULT_POPULATION_SIZE> <DEFAULT_MUTATION_RATE> <DEFAULT_GENERATIONS> <MIN_COLORS> <MAX_COLORS> <NUM_WORKERS> <TIMEOUT_SECONDS>" << endl;
exit(1);
}
GRAPH_PATH = argv[1];
DEFAULT_POPULATION_SIZE = stoi(argv[2]);
DEFAULT_MUTATION_RATE = stod(argv[3]);
DEFAULT_GENERATIONS = stoi(argv[4]);
MIN_COLORS = stoi(argv[5]);
MAX_COLORS = stoi(argv[6]);
NUM_WORKERS = stoi(argv[7]);
TIMEOUT_SECONDS = stoi(argv[8]);
}
// 生成随机图
vector<vector<int>> generateRandomGraph(int vertices, double edgeProbability)
{
random_device rd;
mt19937 gen(rd());
uniform_real_distribution<> dis(0.0, 1.0);
vector<vector<int>> graph(vertices, vector<int>(vertices, 0));
for (int i = 0; i < vertices; i++)
{
for (int j = i + 1; j < vertices; j++)
{
if (dis(gen) < edgeProbability)
{
graph[i][j] = graph[j][i] = 1;
}
}
}
return graph;
}
// 打印图的信息
void printGraphInfo(const vector<vector<int>> &graph)
{
int vertices = graph.size();
int edges = 0;
for (int i = 0; i < vertices; i++)
{
for (int j = i + 1; j < vertices; j++)
{
if (graph[i][j])
edges++;
}
}
// 注释掉或删除不必要的输出
cout << "图的信息:顶点数:" << vertices << ",边数:" << edges << ",平均度:" << (2.0 * edges) / vertices << endl;
}
// 结果结构体
struct ColoringResult
{
bool success;
int numColors;
vector<int> solution;
ColoringResult(bool s = false, int n = 0, vector<int> sol = vector<int>())
: success(s), numColors(n), solution(sol) {}
};
class GraphColoringGA
{
private:
vector<vector<int>> graph; // 邻接矩阵表示的图
int vertices; // 顶点数
int populationSize; // 种群大小
double mutationRate; // 变异率
vector<vector<int>> population; // 种群
// 生成随机数的工具
random_device rd;
mt19937 gen;
uniform_real_distribution<> dis;
uniform_int_distribution<> colorDis;
public:
GraphColoringGA(const vector<vector<int>> &g, int c,
int popSize = DEFAULT_POPULATION_SIZE,
double mutRate = DEFAULT_MUTATION_RATE)
: graph(g), vertices(g.size()),
populationSize(popSize), mutationRate(mutRate),
gen(rd()), dis(0.0, 1.0), colorDis(0, c - 1)
{
initializePopulation();
}
// 初始化种群
void initializePopulation()
{
population.clear();
for (int i = 0; i < populationSize; i++)
{
vector<int> individual(vertices);
for (int j = 0; j < vertices; j++)
{
individual[j] = colorDis(gen);
}
population.push_back(individual);
}
}
// 计算适应度(冲突数的负值)
int calculateFitness(const vector<int> &individual)
{
int conflicts = 0;
for (int i = 0; i < vertices; i++)
{
for (int j = 0; j < vertices; j++)
{
if (graph[i][j] && individual[i] == individual[j])
{
conflicts++;
}
}
}
return -conflicts / 2; // 除以2因为每个冲突被计算了两次
}
// 选择父代
vector<int> selectParent()
{
vector<int> fitnesses;
for (const auto &individual : population)
{
fitnesses.push_back(calculateFitness(individual));
}
// 轮盘赌选择
int totalFitness = 0;
for (int fitness : fitnesses)
{
totalFitness += fitness + vertices; // 加上顶点数使所有适应度为正
}
double r = dis(gen) * totalFitness;
int sum = 0;
for (int i = 0; i < populationSize; i++)
{
sum += fitnesses[i] + vertices;
if (sum > r)
{
return population[i];
}
}
return population.back();
}
// 交叉操作
pair<vector<int>, vector<int>> crossover(const vector<int> &parent1, const vector<int> &parent2)
{
int crossPoint = uniform_int_distribution<>(0, vertices - 1)(gen);
vector<int> child1 = parent1;
vector<int> child2 = parent2;
for (int i = crossPoint; i < vertices; i++)
{
swap(child1[i], child2[i]);
}
return {child1, child2};
}
// 变异操作
void mutate(vector<int> &individual)
{
for (int i = 0; i < vertices; i++)
{
if (dis(gen) < mutationRate)
{
individual[i] = colorDis(gen);
}
}
}
// 添加一个新的方法来检查解的有效性
bool isValidSolution(const vector<int> &solution) const
{
for (int i = 0; i < vertices; i++)
{
for (int j = 0; j < vertices; j++)
{
if (graph[i][j] && solution[i] == solution[j])
{
return false;
}
}
}
return true;
}
// 修改evolve方法,添加超时检查和提前终止标志
vector<int> evolve(int generations, chrono::seconds timeout)
{
auto start_time = chrono::steady_clock::now();
vector<int> bestSolution;
int bestFitness = INT_MIN;
for (int gen = 0; gen < generations; gen++)
{
// 检查是否超时
auto current_time = chrono::steady_clock::now();
if (current_time - start_time > timeout)
{
throw runtime_error("Timeout");
}
vector<vector<int>> newPopulation;
// 精英保留
vector<pair<int, int>> elite;
for (int i = 0; i < populationSize; i++)
{
elite.push_back({calculateFitness(population[i]), i});
}
sort(elite.rbegin(), elite.rend());
newPopulation.push_back(population[elite[0].second]);
// 生成新一代
while (newPopulation.size() < populationSize)
{
vector<int> parent1 = selectParent();
vector<int> parent2 = selectParent();
auto [child1, child2] = crossover(parent1, parent2);
mutate(child1);
mutate(child2);
newPopulation.push_back(child1);
if (newPopulation.size() < populationSize)
{
newPopulation.push_back(child2);
}
}
population = newPopulation;
// 更新最佳解
int currentBestFitness = calculateFitness(population[elite[0].second]);
if (currentBestFitness > bestFitness)
{
bestFitness = currentBestFitness;
bestSolution = population[elite[0].second];
// 如果找到完美解(没有冲突),提前结束
if (bestFitness == 0)
{
break;
}
}
}
return bestSolution;
}
};
// 工作线程函数
void workerThread(const vector<vector<int>> &graph, int numColors, atomic<bool> &foundSolution)
{
try
{
GraphColoringGA ga(graph, numColors);
vector<int> solution = ga.evolve(DEFAULT_GENERATIONS, chrono::seconds(TIMEOUT_SECONDS));
lock_guard<mutex> lock(globalMutex);
if (ga.isValidSolution(solution) && numColors < globalNumColors)
{
foundSolution = true;
globalSolution = solution;
globalNumColors = numColors;
// 在workerThread函数中简化输出
cout << "找到解!使用 " << numColors << " 种颜色,方案:[";
for (int i = 0; i < solution.size(); i++)
{
cout << solution[i];
if (i < solution.size() - 1)
cout << ",";
}
cout << "]" << endl;
}
}
catch (const runtime_error &e)
{
}
}
// 将结果写入CSV文件
void writeResultToCsv(const string &filename, const vector<int> &solution)
{
ofstream outputFile(filename);
for (int color : solution)
{
outputFile << color << ",";
}
outputFile << "\n";
}
// 添加从CSV文件读取图的函数
vector<vector<int>> readGraphFromCsv(const string &filename)
{
ifstream inputFile(filename);
string line;
vector<vector<int>> graph;
while (getline(inputFile, line))
{
vector<int> row;
stringstream ss(line);
string value;
while (getline(ss, value, ','))
{
row.push_back(stoi(value));
}
graph.push_back(row);
}
return graph;
}
// 修改main函数以使用解析后的参数
int main(int argc, char *argv[])
{
// 解析命令行参数
parseCommandLineArguments(argc, argv);
// 设置信号处理
signal(SIGINT, [](int)
{
cout << "\n接收到中断信号,正在退出..." << endl;
exit(0); });
// 从CSV文件读取图
vector<vector<int>> graph = readGraphFromCsv("Optimize_HN/graph_" + GRAPH_PATH + ".csv");
cout << "图的索引: " << GRAPH_PATH << endl;
printGraphInfo(graph);
atomic<bool> foundSolution(false);
int currentColor = MIN_COLORS;
// cout << "\n开始并行寻找最小着色方案..." << endl;
// 按批次处理
while (currentColor <= MAX_COLORS && !foundSolution)
{
vector<thread> currentBatch;
int batchEnd = min(currentColor + NUM_WORKERS, MAX_COLORS + 1);
// 启动当前批次的所有线程
for (int color = currentColor; color < batchEnd && !foundSolution; ++color)
{
// 在main函数中去除启动线程和未找到解的输出
// cout << "启动线程尝试 " << color << " 种颜色" << endl;
currentBatch.emplace_back(workerThread, ref(graph), color, ref(foundSolution));
}
// 等待当前批次所有线程完成
for (auto &thread : currentBatch)
{
thread.join();
}
if (foundSolution)
{
// 将结果写入CSV文件
writeResultToCsv("Optimize_HN/color_" + GRAPH_PATH + ".csv", globalSolution);
break;
}
currentColor = batchEnd;
}
if (!foundSolution)
{
cout << "\n在给定的颜色范围和时间限制内未找到有效的着色方案。" << endl;
}
return 0;
}
# def mapping_weights(weights,value_range):
# new_weights=np.full_like(weights,-1,dtype=int)
# for i in range(len(value_range)):
# new_weights[abs(weights-value_range[i])<=0.01]=i
# return new_weights
# def has_intersection(node1, node2):
# # 使用集合操作加速交集判断
# return 1 if set(node1) & set(node2) else 0
# def generate_color_graph(L, W, value_range,matrix):
# matrix=matrix.transpose(1,0)
# node=[[[] for _ in range(W)] for _ in value_range]
# graph=[[[0 for _ in range(W)] for _ in range(W)] for _ in value_range]
# # 预先生成集合缓存
# node_sets = [[set() for _ in range(W)] for _ in value_range]
# for i in range(W):
# for j in range(L):
# val = matrix[i][j]
# if val!=-1:
# node[val][i].append(j)
# node_sets[val][i].add(j)
# 优化三层循环结构,使用集合操作
# for i in range(len(value_range)):
# for j in range(W):
# current_set = node_sets[i][j]
# for k in range(j, W): # 利用对称性减少一半计算量
# if current_set and node_sets[i][k]:
# has = 1
# graph[i][j][k] = has
# graph[i][k][j] = has # 对称位置赋值
# for i in range(len(value_range)):
# for j in range(W):
# for k in range(W):
# graph[i][j][k]=has_intersection(node[i][j],node[i][k])
# return node,graph
\ No newline at end of file
import enum
import string
class VAR_TYPE(enum.Enum):
WIRE = "wire"
REG = "reg"
GENVAR = "genvar"
INTERGER = "integer"
DEFAULT = "default"
class BaseBlock:
_current_instance_stack = []
def __init__(self) -> None:
self.body = []
def __enter__(self):
BaseBlock._current_instance_stack.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if len(BaseBlock._current_instance_stack) > 1:
father_block = BaseBlock._current_instance_stack[-2]
father_block.add_block(self)
BaseBlock._current_instance_stack.pop()
if exc_type:
print(f"An exception of type {exc_type} occurred.")
print(f"Exception value: {exc_val}")
print(f"Traceback: {exc_tb}")
def add_body(self, line: str):
self.body.append(line)
def generate(self) -> str:
return "\n".join(self.body)
def add_block(self, block):
self.add_body(block.generate())
class ModuleBlock(BaseBlock):
def __init__(self, module_name: str):
super().__init__()
self.module_name = module_name
self.parameters = []
self.inputs = []
self.outputs = []
self.inouts = []
def generate(self) -> str:
module_text = f"module {self.module_name} #(\n"
params = ",\n".join(self.parameters)
inouts = ",\n".join(self.inputs + self.outputs + self.inouts)
bodys = "\n".join(self.body)
# 写入 parameter
module_text += params
module_text += "\n) (\n"
# 写入 inout
module_text += inouts
module_text += "\n);\n"
# 写入 body
module_text += bodys
module_text += "\nendmodule"
return module_text
class AlwaysBlock(BaseBlock):
def __init__(self, sensitivity: str = "posedge clk") -> None:
super().__init__()
self.sensitivity = sensitivity
def generate(self) -> str:
generate_lines = []
generate_lines.append(f"always @({self.sensitivity}) begin")
generate_lines.extend(self.body)
generate_lines.append("end")
return "\n".join(generate_lines)
class IfBlock(BaseBlock):
def __init__(self, condition: str = "rstn") -> None:
super().__init__()
self.condition = condition
def generate(self) -> str:
generate_lines = []
generate_lines.append(f"if ({self.condition}) begin")
generate_lines.extend(self.body)
generate_lines.append("end")
return "\n".join(generate_lines)
class ForBlock(BaseBlock):
def __init__(
self, initial: str, condition: str, update: str, tag: str = ""
) -> None:
super().__init__()
self.initial = initial
self.update = update
self.condition = condition
self.tag = tag
def generate(self) -> str:
generate_lines = []
if self.tag:
generate_lines.append(
f"for ({self.initial}; {self.condition}; {self.update}) begin : {self.tag}"
)
else:
generate_lines.append(
f"for ({self.initial}; {self.condition}; {self.update}) begin"
)
generate_lines.extend(self.body)
generate_lines.append("end")
return "\n".join(generate_lines)
class ElseBlock(BaseBlock):
def __init__(self) -> None:
super().__init__()
def generate(self) -> str:
generate_lines = []
generate_lines.append("else begin")
generate_lines.extend(self.body)
generate_lines.append("end")
return "\n".join(generate_lines)
class GenerateBlock(BaseBlock):
def __init__(self) -> None:
super().__init__()
def generate(self) -> str:
generate_lines = []
generate_lines.append("generate")
generate_lines.extend(self.body)
generate_lines.append("endgenerate")
return "\n".join(generate_lines)
class VerilogGenerator(BaseBlock):
def __init__(self) -> None:
super().__init__()
class _VerilogSentence:
@staticmethod
def parameter_sentence(name: str, value: str) -> str:
return f"parameter {name} = {value}"
@staticmethod
def input_sentence(
var_type: VAR_TYPE = VAR_TYPE.DEFAULT,
name: str = "",
width: str = "1",
height: str = "1",
) -> str:
input_template = string.Template("input $var_type$width$name$height")
if var_type == VAR_TYPE.DEFAULT:
var_type = ""
else:
var_type = f"{var_type} "
if width == "1":
width = ""
else:
width = f"[{width} - 1 : 0] "
if height == "1":
height = ""
else:
height = f"[{height} - 1 : 0]"
input_sentence = input_template.substitute(
var_type=var_type, width=width, name=name, height=height
)
return input_sentence
@staticmethod
def output_sentence(
var_type: VAR_TYPE, name: str, width: str = "1", height: str = "1"
):
output_template = string.Template("output $var_type$width$name$height")
if var_type == VAR_TYPE.DEFAULT:
var_type = ""
else:
var_type = f"{var_type} "
if width == "1":
width = ""
else:
width = f"[{width} - 1 : 0] "
if height == "1":
height = ""
else:
height = f"[{height} - 1 : 0]"
output_sentence = output_template.substitute(
var_type=var_type, width=width, name=name, height=height
)
return output_sentence
@staticmethod
def inout_sentence(
var_type: VAR_TYPE, name: str, width: str = "1", height: str = "1"
):
inout_template = string.Template("inout $var_type$width$name$height")
if var_type == VAR_TYPE.DEFAULT:
var_type = ""
else:
var_type = f"{var_type} "
if width == "1":
width = ""
else:
width = f"[{width} - 1 : 0] "
if height == "1":
height = ""
else:
height = f"[{height} - 1 : 0]"
inout_sentence = inout_template.substitute(
var_type=var_type, width=width, name=name, height=height
)
return inout_sentence
@staticmethod
def assign_sentence(
out_var: str, out_indices: list, in_var: str, in_indexs: list
) -> str:
assign_template = string.Template(
"assign $out_var$out_indices = $in_var$in_indices;"
)
out_indices_str = "".join([f"[{index}]" for index in out_indices])
in_indices_str = "".join([f"[{index}]" for index in in_indexs])
assign_sentence = assign_template.substitute(
out_var=out_var,
out_indices=out_indices_str,
in_var=in_var,
in_indices=in_indices_str,
)
return assign_sentence
@staticmethod
def instance_sentence(
module_name: str, instance_name: str, parameters: dict, ports: dict
) -> str:
instance_lines = [f"{module_name} "]
if parameters:
instance_lines = [f"{module_name} #("]
for param_name, param_value in parameters.items():
instance_lines.append(f" .{param_name}({param_value}),")
if instance_lines[-1].endswith(","):
instance_lines[-1] = instance_lines[-1][:-1]
instance_lines.append(")")
instance_lines.append(f"{instance_name} (")
for port_name, signal_name in ports.items():
instance_lines.append(f" .{port_name}({signal_name}),")
if instance_lines[-1].endswith(","):
instance_lines[-1] = instance_lines[-1][:-1]
instance_lines.append(");")
return "\n".join(instance_lines)
@staticmethod
def var_sentence(
var_type: VAR_TYPE, name: str, width: str = "1", height: str = "1"
) -> str:
var_template = string.Template("$var_type $width$name$height;")
if width == "1":
width = ""
else:
width = f"[{width} - 1 : 0] "
if height == "1":
height = ""
else:
height = f"[{height} - 1 : 0]"
var_sentence = var_template.substitute(
var_type=var_type.value, width=width, name=name, height=height
)
return var_sentence
def _find_father() -> BaseBlock:
return BaseBlock._current_instance_stack[-1]
def add_body(line: str):
father = _find_father()
father.add_body(line)
def add_newline():
add_body("")
def add_parameter(name: str, value: str):
father = _find_father()
try:
assert father.__class__ == ModuleBlock
father.parameters.append(_VerilogSentence.parameter_sentence(name, value))
except Exception as e:
raise e
def add_input(
name: str,
width: str = "1",
height: str = "1",
var_type: VAR_TYPE = VAR_TYPE.DEFAULT,
):
father = _find_father()
try:
assert father.__class__ == ModuleBlock
father.inputs.append(
_VerilogSentence.input_sentence(var_type, name, str(width), str(height))
)
except Exception as e:
raise e
def add_output(
name: str,
width: str = "1",
height: str = "1",
var_type: VAR_TYPE = VAR_TYPE.DEFAULT,
):
father = _find_father()
try:
assert father.__class__ == ModuleBlock
father.outputs.append(
_VerilogSentence.output_sentence(var_type, name, str(width), str(height))
)
except Exception as e:
raise e
def add_inout(
name: str,
width: str = "1",
height: str = "1",
var_type: VAR_TYPE = VAR_TYPE.DEFAULT,
):
father = _find_father()
try:
assert father.__class__ == ModuleBlock
father.inouts.append(
_VerilogSentence.inout_sentence(var_type, name, str(width), str(height))
)
except Exception as e:
raise e
def add_assign(out_var: str, out_indices: list, in_var: str, in_indexs: list):
father = _find_father()
father.add_body(
_VerilogSentence.assign_sentence(out_var, out_indices, in_var, in_indexs)
)
def add_instance(module_name: str, instance_name: str, parameters: dict, ports: dict):
father = _find_father()
father.add_body(
_VerilogSentence.instance_sentence(
module_name, instance_name, parameters, ports
)
)
def add_var(var_type: VAR_TYPE, name: str, width: str = "1", height: str = "1"):
father = _find_father()
father.add_body(_VerilogSentence.var_sentence(var_type, name, width, height))
def add_wire(name: str, width: str = "1", height: str = "1"):
add_var(VAR_TYPE.WIRE, name, width, height)
def add_reg(name: str, width: str = "1", height: str = "1"):
add_var(VAR_TYPE.REG, name, width, height)
def add_genvar(name: str):
add_var(VAR_TYPE.GENVAR, name)
def add_integer(name: str):
add_var(VAR_TYPE.INTERGER, name)
def _test_generate():
with VerilogGenerator() as generator:
with ModuleBlock("Test"):
add_parameter("adsa", "dsa")
with IfBlock("rstn"):
add_assign("a", [], "b", [])
with ElseBlock():
add_assign("a", [], "c", [])
with ModuleBlock("Test11"):
add_parameter("adsa", "dsa")
with IfBlock("rstn"):
add_assign("a", [], "b", [])
with ElseBlock():
add_assign("a", [], "c", [])
print(len(generator.body))
with open("test.sv", "w") as f:
f.write(generator.generate())
if __name__ == "__main__":
_test_generate()
import os
import sys
import pickle
from config import CFG
import generate_ww
import generate_sub_wrapper
import generate_lm
import generate_wt_group
import generate_wrappers
import generate_mid_wrapper
import generate_wallace
import generate_fsm
import generate_hn
def run_generate_verilog():
generate_ww.run()
print("生成WW完成")
generate_sub_wrapper.run()
print("生成Sub_wrapper完成")
generate_hn.run()
print("生成HN完成")
generate_lm.run()
print("生成Layer_mux完成")
generate_wt_group.run()
print("生成WT_group完成")
generate_wrappers.run()
print("生成Wrappers完成")
generate_mid_wrapper.run()
print("生成Mid_wrappers完成")
generate_wallace.run()
print("生成Wallace Tree完成")
generate_fsm.run()
print("生成FSM完成")
def print_menu():
print("\n可用的生成选项:")
print("1. 生成 WW")
print("2. 生成 Sub_wrapper")
print("3. 生成 HN")
print("4. 生成 Layer_mux")
print("5. 生成 WT_group")
print("6. 生成 Wrappers")
print("7. 生成 Mid_wrappers")
print("8. 生成 Wallace Tree")
print("9. 生成 FSM")
print("10. 生成全部")
print("0. 退出")
return input("请选择要生成的模块 (0-10): ")
def run_selected_generate(choice):
if choice == "1":
generate_ww.run()
print("生成WW完成")
elif choice == "2":
generate_sub_wrapper.run()
print("生成Sub_wrapper完成")
elif choice == "3":
generate_hn.run()
print("生成HN完成")
elif choice == "4":
generate_lm.run()
print("生成Layer_mux完成")
elif choice == "5":
generate_wt_group.run()
print("生成WT_group完成")
elif choice == "6":
generate_wrappers.run()
print("生成Wrappers完成")
elif choice == "7":
generate_mid_wrapper.run()
print("生成Mid_wrappers完成")
elif choice == "8":
generate_wallace.run()
print("生成Wallace Tree完成")
elif choice == "9":
generate_fsm.run()
print("生成FSM完成")
elif choice == "10":
run_generate_verilog()
if __name__ == "__main__":
print("请选择运行模式:")
print(
f"1. 使用 run_weights_batch(批量运行), 当前权重文件:{CFG.run_weights_batch}"
)
print(f"2. 使用 run_weights(单次运行), 当前权重文件:{CFG.run_weights}")
mode = input("请选择 (1/2): ")
if mode == "1":
while True:
choice = print_menu()
if choice == "0":
break
for weights in CFG.run_weights_batch:
print(f"\n正在处理 weights: {weights}")
CFG.run_weights = weights
run_selected_generate(choice)
elif mode == "2":
while True:
choice = print_menu()
if choice == "0":
break
run_selected_generate(choice)
else:
print("无效的选择!")
from setuptools import setup, Extension
import pybind11
ext_modules = [
Extension(
'optimize_HN',
['optimize_HN.cpp'],
include_dirs=[pybind11.get_include()],
language='c++',
extra_compile_args=['-std=c++11'],
extra_link_args=['-static-libstdc++'],
),
]
setup(
name='optimize_HN',
version='0.1',
ext_modules=ext_modules,
)
\ No newline at end of file
# %%
from multiprocessing import Pool
import pickle
import numpy as np
from tqdm import tqdm
from prettytable import PrettyTable
import os
from concurrent.futures import ProcessPoolExecutor
from hllm.config import CFG
from hllm.utils import calculate_WW, find_index
# 返回第i位
def get_bit(num, i):
if i < 0:
return 0
return (num >> i) & 1
# %%
class HN:
def __init__(self, matrix, H, L):
self.matrix = matrix
self.H = H
self.L = L
# def find_index(self, value):
# return np.searchsorted(CFG.value_range, value)
def calculate(self, HN_in: np.ndarray):
HN_out = np.zeros((self.L, len(CFG.value_range)), dtype=int)
ans = np.zeros(self.L)
matrix_masked = self.matrix * HN_in
for i, layer in enumerate(matrix_masked):
for j, value in enumerate(layer):
if abs(value) <= 1e-3:
continue
index = find_index(CFG.value_range, value)
HN_out[i][index] += 1
ans[i] += value
# indices=list(map(self.find_index,layer))
# np.add.at(HN_out[i],indices,1)
return HN_out, ans
class HN_GROUP:
def __init__(self, weights: np.ndarray):
self.VN, self.L, self.H = weights.shape
print(weights.shape)
self.HN_GROUP = [HN(matrix, self.H, self.L) for matrix in weights]
print("HN_GROUP init done")
def calculate_single(self, hn_group, hn_in):
return hn_group.calculate(hn_in)
def calculate(self, hn_in: np.ndarray):
hn_out = [None] * len(self.HN_GROUP)
ans = [None] * len(self.HN_GROUP)
with ProcessPoolExecutor(max_workers=CFG.num_workers) as executor:
futures = [
executor.submit(hn_group.calculate, hn_in) for hn_group in self.HN_GROUP
]
for i, future in enumerate(tqdm(futures)):
hn_out[i], ans[i] = future.result()
ans = np.array(ans)
return ans
def run(config: CFG):
activation_name = os.path.join(config.verify_dir, "activation.pkl")
result_name = os.path.join(config.verify_dir, "result.txt")
with open(activation_name, "rb") as f:
hn_in = pickle.load(f)
weights_path = os.path.join(config.verify_dir, config.verify_weights)
with open(weights_path, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
hn_group = HN_GROUP(matrixs)
hn_in = get_bit(hn_in, 7)
hn_out = hn_group.calculate(hn_in)
print(hn_out)
import hllm.optimized.turbo_optimize_hn as turbo_optimize_hn
import numpy as np
def test_graph_coloring():
# 创建一个100x100的随机稀疏邻接矩阵
np.random.seed(41) # 为了结果可重复,设置随机种子
size = 1536
weight = np.random.randint(-1, 2, size=(size, size))
adj_matrix = np.random.randint(0, 2, size=(size, size))
adj_matrix = np.triu(adj_matrix, 1) # 只保留上三角部分
adj_matrix += adj_matrix.T # 对称化矩阵
print("turbo_optimize_hn 模块位置:", turbo_optimize_hn.__file__)
# 调用图着色算法
colors = turbo_optimize_hn.greedy_coloring(adj_matrix, weight)
# 打印结果
# print("节点颜色分配结果:", colors)
print(max(colors))
# 验证结果是否有效
n = len(adj_matrix)
for i in range(n):
for j in range(n):
if adj_matrix[i][j] == 1:
# 相邻节点不应该有相同的颜色
assert (
colors[i] != colors[j] and colors[i] != -1 and colors[j] != -1
), f"相邻节点 {i} 和 {j} 具有相同的颜色!"
print("测试通过!所有相邻节点都有不同的颜色")
if __name__ == "__main__":
test_graph_coloring()
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pyrilog import VerilogGenerator,ModuleBlock,add_parameter,add_input,add_output,add_assign\n",
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"import os\n",
"from tqdm import tqdm\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"class CFG:\n",
" path_dir = \"HN\"\n",
" value_range = [-6, -4, -3, -2, -1.5, -1, -0.5, 0.5, 1, 1.5, 2, 3, 4, 6]\n",
" WW = [8] * len(value_range)\n",
"\n",
"\n",
"def calculate_WW(matrix: np.array, value_range):\n",
" WW = [0] * len(value_range)\n",
" for i in range(len(value_range)):\n",
" WW[i] = max([len([x for x in row if abs(x-value_range[i])<=0.01]) for row in matrix])\n",
" return WW\n",
"\n",
"def find_index(arr, target, epsilon=1e-3):\n",
" arr = np.array(arr) # 转换为numpy数组\n",
" diff = np.abs(arr - target) # 计算差值数组\n",
" min_diff = np.min(diff) # 找到最小的差值\n",
" if min_diff < epsilon: # 如果最小差值在允许的误差范围内\n",
" return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引\n",
" raise ValueError('No match found') # 如果没有找到匹配项,则引发异常"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def generate_verilog_code(\n",
" matrix,\n",
" HN_id=0,\n",
" H=16,\n",
" L=5,\n",
" value_range=[-1, 1],\n",
" WW=[8,8],\n",
"):\n",
" with VerilogGenerator() as generator:\n",
" with ModuleBlock(f\"HN_{HN_id}\"):\n",
" add_parameter(\"H\", H)\n",
" add_parameter(\"L\", L)\n",
" for i in range(len(value_range)):\n",
" add_parameter(f\"WW_{i}\", WW[i])\n",
" add_input(name=\"HN_in\", width=\"H\")\n",
" for i in range(len(value_range)):\n",
" add_output(\n",
" name=f\"HN_out_{i}\",\n",
" width=f\"WW_{i}\",\n",
" height=\"L\",\n",
" )\n",
" # 内部连线\n",
" for i, layer in enumerate(matrix):\n",
" weight_cnt = [0] * len(value_range)\n",
" for j, weight in enumerate(layer):\n",
" if abs(weight)<1e-3:\n",
" continue\n",
" try:\n",
" index=find_index(value_range, weight)\n",
" except ValueError:\n",
" print(f\"weight {weight} not found\")\n",
" continue\n",
" add_assign(\n",
" f\"HN_out_{index}\",\n",
" [i, weight_cnt[index]],\n",
" \"HN_in\",\n",
" [j],\n",
" )\n",
" weight_cnt[index] += 1\n",
" for i in range(len(weight_cnt)):\n",
" while weight_cnt[i] < WW[i]:\n",
" add_assign(f\"HN_out_{i}\", [i, weight_cnt[i]], \"0\", [])\n",
" weight_cnt[i] += 1\n",
" return generator.generate()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1536 [00:00<?, ?it/s]\n"
]
},
{
"ename": "NameError",
"evalue": "name 'result' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mBrokenProcessPool\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[24], line 36\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 35\u001b[0m \u001b[38;5;66;03m# print(1)\u001b[39;00m\n\u001b[1;32m---> 36\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfuture\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 37\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n",
"File \u001b[1;32m~\\AppData\\Roaming\\uv\\python\\cpython-3.11.10-windows-x86_64-none\\Lib\\concurrent\\futures\\_base.py:449\u001b[0m, in \u001b[0;36mFuture.result\u001b[1;34m(self, timeout)\u001b[0m\n\u001b[0;32m 448\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;241m==\u001b[39m FINISHED:\n\u001b[1;32m--> 449\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 451\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_condition\u001b[38;5;241m.\u001b[39mwait(timeout)\n",
"File \u001b[1;32m~\\AppData\\Roaming\\uv\\python\\cpython-3.11.10-windows-x86_64-none\\Lib\\concurrent\\futures\\_base.py:401\u001b[0m, in \u001b[0;36mFuture.__get_result\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 400\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 401\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception\n\u001b[0;32m 402\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m 403\u001b[0m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n",
"\u001b[1;31mBrokenProcessPool\u001b[0m: A process in the process pool was terminated abruptly while the future was running or pending.",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[24], line 38\u001b[0m\n\u001b[0;32m 36\u001b[0m result \u001b[38;5;241m=\u001b[39m future\u001b[38;5;241m.\u001b[39mresult()\n\u001b[0;32m 37\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m---> 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGenerating \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[43mresult\u001b[49m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m failed with an error: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[1;31mNameError\u001b[0m: name 'result' is not defined"
]
}
],
"source": [
"def process_task(i, matrix, H, L):\n",
" try:\n",
" WW = calculate_WW(matrix, CFG.value_range)\n",
" file_name = os.path.join(CFG.path_dir, f\"HN_{i}.sv\")\n",
" with open(file_name, \"w\") as f:\n",
" f.write(\n",
" generate_verilog_code(\n",
" matrix,\n",
" HN_id=i,\n",
" H=H,\n",
" L=L,\n",
" value_range=CFG.value_range,\n",
" WW=WW,\n",
" )\n",
" )\n",
" return i # 返回任务ID以显示进度\n",
" except Exception as e:\n",
" print(f\"Generating {i} failed with an error: {e}\")\n",
" return None\n",
"\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" os.makedirs(CFG.path_dir, exist_ok=True)\n",
" with open(r\"C:\\Users\\night\\Documents\\Codes\\H-LLM\\weights\\q_proj.pkl\", \"rb\") as f:\n",
" matrixs = pickle.load(f)\n",
" matrixs = np.transpose(matrixs, (1, 0, 2))\n",
" VN, L, H = matrixs.shape\n",
" with ProcessPoolExecutor(max_workers=8) as executor:\n",
" futures = [\n",
" executor.submit(process_task, i, matrixs[i], H, L) for i in range(VN)\n",
" ]\n",
" for future in tqdm(as_completed(futures), total=VN):\n",
" try:\n",
" # print(1)\n",
" result = future.result()\n",
" except Exception as e:\n",
" print(f\"Generating {result} failed with an error: {e}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pyrilog import (\n",
" VerilogGenerator,\n",
" ModuleBlock,\n",
" GenerateBlock,\n",
" ForBlock,\n",
" add_parameter,\n",
" add_input,\n",
" add_output,\n",
" add_genvar,\n",
" add_assign,\n",
" add_wire,\n",
" add_body,\n",
" add_instance,\n",
" add_newline,\n",
")\n",
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"import os\n",
"from tqdm import tqdm\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class CFG:\n",
" path_dir = \"Layer_mux\"\n",
" weights_dir = \"../001-H-LLM/weights\"\n",
" num_workers = 16\n",
" value_range = [-6, -4, -3, -2, -1.5, -1, -0.5, 0.5, 1, 1.5, 2, 3, 4, 6]\n",
"\n",
"\n",
"def calculate_WW(matrix: np.array, value_range):\n",
" WW = [0] * len(value_range)\n",
" for i in range(len(value_range)):\n",
" WW[i] = max(\n",
" [len([x for x in row if abs(x - value_range[i]) <= 0.01]) for row in matrix]\n",
" )\n",
" return WW\n",
"\n",
"\n",
"def find_index(arr, target, epsilon=1e-3):\n",
" arr = np.array(arr) # 转换为numpy数组\n",
" diff = np.abs(arr - target) # 计算差值数组\n",
" min_diff = np.min(diff) # 找到最小的差值\n",
" if min_diff < epsilon: # 如果最小差值在允许的误差范围内\n",
" return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引\n",
" raise ValueError(\"No match found\") # 如果没有找到匹配项,则引发异常"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def generate_module(\n",
" matrix,\n",
" module_name_suffix=\"\",\n",
" H=16,\n",
" L=5,\n",
" value_range=[-1, 1],\n",
" WW=[8, 8],\n",
"):\n",
" with ModuleBlock(f\"{CFG.path_dir}_{module_name_suffix}\") as module:\n",
" # 参数\n",
" add_parameter(\"L\", L)\n",
" for i in range(len(value_range)):\n",
" add_parameter(f\"WW_{i}\", WW[i])\n",
" # 输入输出\n",
" add_input(\"LM_sel\", \"L\")\n",
" for i in range(len(value_range)):\n",
" add_input(\n",
" f\"LM_in_{i}\",\n",
" f\"WW_{i}\",\n",
" \"L\"\n",
" )\n",
" add_output(\n",
" f\"LM_out_{i}\",\n",
" f\"WW_{i}\"\n",
" )\n",
" # 内部连线\n",
" for i in range(len(value_range)):\n",
" add_wire(\n",
" name=f\"LM_in_{i}_masked\",\n",
" width=f\"WW_{i}\",\n",
" height=\"L\"\n",
" )\n",
" add_wire(\n",
" name=f\"LM_in_{i}_masked_T\",\n",
" width=\"L\",\n",
" height=f\"WW_{i}\"\n",
" )\n",
" add_newline()\n",
" # LM_select_loop\n",
" add_genvar(\"i\")\n",
" with GenerateBlock():\n",
" with ForBlock(\"i=0\",\"i<L\",\"i=i+1\",\"LM_select_loop\"):\n",
" for j in range(len(value_range)):\n",
" add_body(\n",
" f\"assign LM_in_{j}_masked[i]=LM_in_{j}[i] & {{WW_{j}{{LM_sel[i]}}}}\",\n",
" )\n",
" add_newline() \n",
" # LM_transpose_loop_out\n",
" add_genvar(\"j\")\n",
" add_genvar(\"k\")\n",
" with GenerateBlock():\n",
" with ForBlock(\"k=0\",\"k<L\",\"k=k+1\",\"LM_transpose_loop_out\"):\n",
" for i in range(len(value_range)):\n",
" with ForBlock(\"j=0\",f\"j<WW_{i}\",\"j=j+1\",f\"LM_transpose_loop_in_{i}\"):\n",
" add_assign(f\"LM_in_{i}_masked_T\",[\"j\",\"k\"],f\"LM_in_{i}_masked\",[\"k\",\"j\"])\n",
" add_newline()\n",
" # LM_reduce_or_loop\n",
" add_genvar(\"m\")\n",
" with GenerateBlock():\n",
" for i in range(len(value_range)):\n",
" with ForBlock(\"m=0\",f\"m<WW_{i}\",\"m=m+1\",f\"LM_reduce_or_loop_{i}\"):\n",
" add_body(f\"assign LM_out_{i}[m] = |LM_in_{i}_masked_T[m]|\")\n",
" \n",
" return module"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing k\n",
"Loading k\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|▊ | 41/512 [00:03<00:45, 10.31it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[4], line 40\u001b[0m\n\u001b[0;32m 38\u001b[0m VN, L, H \u001b[38;5;241m=\u001b[39m matrixs\u001b[38;5;241m.\u001b[39mshape\n\u001b[0;32m 39\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(VN)):\n\u001b[1;32m---> 40\u001b[0m \u001b[43mprocess_task\u001b[49m\u001b[43m(\u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweights_file_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmatrixs\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mH\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 41\u001b[0m \u001b[38;5;66;03m# with ProcessPoolExecutor(max_workers=CFG.num_workers) as executor:\u001b[39;00m\n\u001b[0;32m 42\u001b[0m \u001b[38;5;66;03m# futures = [\u001b[39;00m\n\u001b[0;32m 43\u001b[0m \u001b[38;5;66;03m# executor.submit(process_task, i, weights_file_name, matrixs[i], H, L)\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[38;5;66;03m# except Exception as e:\u001b[39;00m\n\u001b[0;32m 50\u001b[0m \u001b[38;5;66;03m# print(f\"Generating {result} failed with an error: {e}\")\u001b[39;00m\n",
"Cell \u001b[1;32mIn[4], line 3\u001b[0m, in \u001b[0;36mprocess_task\u001b[1;34m(i, weights_file_name, matrix, H, L)\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprocess_task\u001b[39m(i, weights_file_name, matrix, H, L):\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m----> 3\u001b[0m WW \u001b[38;5;241m=\u001b[39m \u001b[43mcalculate_WW\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmatrix\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCFG\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue_range\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m file_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(CFG\u001b[38;5;241m.\u001b[39mpath_dir, weights_file_name)\n\u001b[0;32m 5\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(file_dir, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"Cell \u001b[1;32mIn[2], line 12\u001b[0m, in \u001b[0;36mcalculate_WW\u001b[1;34m(matrix, value_range)\u001b[0m\n\u001b[0;32m 9\u001b[0m WW \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(value_range)\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(value_range)):\n\u001b[0;32m 11\u001b[0m WW[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\n\u001b[1;32m---> 12\u001b[0m \u001b[43m[\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrow\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mabs\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mvalue_range\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m<\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.01\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrow\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmatrix\u001b[49m\u001b[43m]\u001b[49m\n\u001b[0;32m 13\u001b[0m )\n\u001b[0;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m WW\n",
"Cell \u001b[1;32mIn[2], line 12\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 9\u001b[0m WW \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(value_range)\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(value_range)):\n\u001b[0;32m 11\u001b[0m WW[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\n\u001b[1;32m---> 12\u001b[0m [\u001b[38;5;28mlen\u001b[39m(\u001b[43m[\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrow\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mabs\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mvalue_range\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m<\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.01\u001b[39;49m\u001b[43m]\u001b[49m) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m matrix]\n\u001b[0;32m 13\u001b[0m )\n\u001b[0;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m WW\n",
"Cell \u001b[1;32mIn[2], line 12\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 9\u001b[0m WW \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(value_range)\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(value_range)):\n\u001b[0;32m 11\u001b[0m WW[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\n\u001b[1;32m---> 12\u001b[0m [\u001b[38;5;28mlen\u001b[39m([x \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m row \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mabs\u001b[39m(x \u001b[38;5;241m-\u001b[39m value_range[i]) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.01\u001b[39m]) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m matrix]\n\u001b[0;32m 13\u001b[0m )\n\u001b[0;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m WW\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"def process_task(i, weights_file_name, matrix, H, L):\n",
" try:\n",
" WW = calculate_WW(matrix, CFG.value_range)\n",
" file_dir = os.path.join(CFG.path_dir, weights_file_name)\n",
" os.makedirs(file_dir, exist_ok=True)\n",
" file_name = os.path.join(\n",
" file_dir, f\"{CFG.path_dir}_tp_{weights_file_name}_vc_{i}.sv\"\n",
" )\n",
" with open(file_name, \"w\") as f:\n",
" f.write(\n",
" generate_module(\n",
" matrix,\n",
" module_name_suffix=f\"_tp_{weights_file_name}_vc_{i}\",\n",
" H=H,\n",
" L=L,\n",
" value_range=CFG.value_range,\n",
" WW=WW,\n",
" ).generate()\n",
" )\n",
" return i # 返回任务ID以显示进度\n",
" except Exception as e:\n",
" print(f\"Generating {i} failed with an error: {e}\")\n",
" return None\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" os.makedirs(CFG.path_dir, exist_ok=True)\n",
" for weights_file in os.listdir(CFG.weights_dir):\n",
" if weights_file != \"k.pkl\":\n",
" continue\n",
" weights_path = os.path.join(CFG.weights_dir, weights_file)\n",
" weights_file_name = os.path.splitext(weights_file)[0]\n",
" print(f\"Processing {weights_file_name}\")\n",
" with open(weights_path, \"rb\") as f:\n",
" print(f\"Loading {weights_file_name}\")\n",
" matrixs = pickle.load(f)\n",
" matrixs = np.transpose(matrixs, (1, 0, 2))\n",
" VN, L, H = matrixs.shape\n",
" for i in tqdm(range(VN)):\n",
" process_task(i, weights_file_name, matrixs[i], H, L)\n",
" # with ProcessPoolExecutor(max_workers=CFG.num_workers) as executor:\n",
" # futures = [\n",
" # executor.submit(process_task, i, weights_file_name, matrixs[i], H, L)\n",
" # for i in range(VN)\n",
" # ]\n",
" # for future in tqdm(as_completed(futures), total=VN):\n",
" # try:\n",
" # result = future.result()\n",
" # except Exception as e:\n",
" # print(f\"Generating {result} failed with an error: {e}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pyrilog import (\n",
" VerilogGenerator,\n",
" ModuleBlock,\n",
" add_parameter,\n",
" add_input,\n",
" add_output,\n",
" add_assign,\n",
" add_wire,\n",
" add_instance,\n",
")\n",
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"import os\n",
"from tqdm import tqdm\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class CFG:\n",
" path_dir = \"Sub_wrapper\"\n",
" weights_dir = \"../001-H-LLM/weights\"\n",
" num_workers = 16\n",
" value_range = [-6, -4, -3, -2, -1.5, -1, -0.5, 0.5, 1, 1.5, 2, 3, 4, 6]\n",
"\n",
"\n",
"def calculate_WW(matrix: np.array, value_range):\n",
" WW = [0] * len(value_range)\n",
" for i in range(len(value_range)):\n",
" WW[i] = max(\n",
" [len([x for x in row if abs(x - value_range[i]) <= 0.01]) for row in matrix]\n",
" )\n",
" return WW\n",
"\n",
"\n",
"def find_index(arr, target, epsilon=1e-3):\n",
" arr = np.array(arr) # 转换为numpy数组\n",
" diff = np.abs(arr - target) # 计算差值数组\n",
" min_diff = np.min(diff) # 找到最小的差值\n",
" if min_diff < epsilon: # 如果最小差值在允许的误差范围内\n",
" return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引\n",
" raise ValueError(\"No match found\") # 如果没有找到匹配项,则引发异常"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def generate_module(\n",
" matrix,\n",
" module_name_suffix=\"\",\n",
" H=16,\n",
" L=5,\n",
" value_range=[-1, 1],\n",
" WW=[8, 8],\n",
"):\n",
" with ModuleBlock(f\"{CFG.path_dir}_{module_name_suffix}\") as module:\n",
" # 参数\n",
" add_parameter(\"H\", H)\n",
" add_parameter(\"L\", L)\n",
" for i in range(len(value_range)):\n",
" add_parameter(f\"WW_{i}\", WW[i])\n",
" # 输入输出\n",
" add_input(\"clk\")\n",
" add_input(\"tre_rstn\")\n",
" add_input(\"valid\")\n",
" add_input(\"LM_sel\", \"L\")\n",
" add_input(\"SW_in\",\"H\")\n",
" for i in range(len(value_range)):\n",
" add_output(\n",
" name=f\"WT_{i}_out_S\",\n",
" )\n",
" add_output(\n",
" name=f\"WT_{i}_out_C\",\n",
" )\n",
" # 内部连线\n",
" for i in range(len(value_range)):\n",
" add_wire(\n",
" name=f\"HN_out_{i}\",\n",
" width=f\"WW_{i}\",\n",
" height=\"L\"\n",
" )\n",
" add_wire(\n",
" name=f\"LM_out_{i}\",\n",
" width=f\"WW_{i}\",\n",
" )\n",
" \n",
" # 实例化HN\n",
" hn_params = {\n",
" \"H\": H,\n",
" \"L\": L,\n",
" }\n",
" for i in range(len(value_range)):\n",
" hn_params[f\"WW_{i}\"] = f\"WW_{i}\"\n",
" hn_ports = {\n",
" \"HN_in\": \"SW_in\",\n",
" }\n",
" for i in range(len(value_range)):\n",
" hn_ports[f\"HN_out_{i}\"] = f\"HN_out_{i}\"\n",
" add_instance(\"HN\"+module_name_suffix, \"hn\", hn_params, hn_ports)\n",
" \n",
" # 实例化LM\n",
" lm_params = {\n",
" \"L\": L,\n",
" }\n",
" for i in range(len(value_range)):\n",
" lm_params[f\"WW_{i}\"] = f\"WW_{i}\"\n",
" lm_ports = {\n",
" \"LM_sel\": \"LM_sel\",\n",
" }\n",
" for i in range(len(value_range)):\n",
" lm_ports[f\"LM_in_{i}\"] = f\"HN_out_{i}\"\n",
" lm_ports[f\"LM_out_{i}\"] = f\"LM_out_{i}\"\n",
" add_instance(\"Layer_mux\"+module_name_suffix, \"layer_mux\", lm_params, lm_ports)\n",
"\n",
" # 实例化WT\n",
" wt_params = {}\n",
" for i in range(len(value_range)):\n",
" wt_params[f\"WW_{i}\"] = f\"WW_{i}\"\n",
" wt_ports = {\"clk\":\"clk\",\"tre_rstn\":\"tre_rstn\",\"valid\":\"valid\"}\n",
" for i in range(len(value_range)):\n",
" wt_ports[f\"WT_{i}_in\"] = f\"LM_out_{i}\"\n",
" wt_ports[f\"WT_{i}_out_S\"] = f\"WT_{i}_out_S\"\n",
" wt_ports[f\"WT_{i}_out_C\"] = f\"WT_{i}_out_C\"\n",
" add_instance(\"WT_group\"+module_name_suffix, \"wt_group\", wt_params, wt_ports)\n",
" return module"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def process_task(i, weights_file_name, matrix, H, L):\n",
" try:\n",
" WW = calculate_WW(matrix, CFG.value_range)\n",
" file_dir = os.path.join(CFG.path_dir, weights_file_name)\n",
" os.makedirs(file_dir, exist_ok=True)\n",
" file_name = os.path.join(\n",
" file_dir, f\"{CFG.path_dir}_tp_{weights_file_name}_vc_{i}.sv\"\n",
" )\n",
" with open(file_name, \"w\") as f:\n",
" f.write(\n",
" generate_module(\n",
" matrix,\n",
" module_name_suffix=f\"_tp_{weights_file_name}_vc_{i}\",\n",
" H=H,\n",
" L=L,\n",
" value_range=CFG.value_range,\n",
" WW=WW,\n",
" ).generate()\n",
" )\n",
" return i # 返回任务ID以显示进度\n",
" except Exception as e:\n",
" print(f\"Generating {i} failed with an error: {e}\")\n",
" return None\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" os.makedirs(CFG.path_dir, exist_ok=True)\n",
" for weights_file in os.listdir(CFG.weights_dir):\n",
" if weights_file != \"k.pkl\":\n",
" continue\n",
" weights_path = os.path.join(CFG.weights_dir, weights_file)\n",
" weights_file_name = os.path.splitext(weights_file)[0]\n",
" print(f\"Processing {weights_file_name}\")\n",
" with open(weights_path, \"rb\") as f:\n",
" print(f\"Loading {weights_file_name}\")\n",
" matrixs = pickle.load(f)\n",
" matrixs = np.transpose(matrixs, (1, 0, 2))\n",
" VN, L, H = matrixs.shape\n",
" for i in tqdm(range(VN)):\n",
" process_task(i, weights_file_name, matrixs[i], H, L)\n",
" # with ProcessPoolExecutor(max_workers=CFG.num_workers) as executor:\n",
" # futures = [\n",
" # executor.submit(process_task, i, weights_file_name, matrixs[i], H, L)\n",
" # for i in range(VN)\n",
" # ]\n",
" # for future in tqdm(as_completed(futures), total=VN):\n",
" # try:\n",
" # result = future.result()\n",
" # except Exception as e:\n",
" # print(f\"Generating {result} failed with an error: {e}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import numpy as np
import numpy as np
import Pyrilog.Pyrilog as pl
import os
class CFG:
path_dir = "TFSM"
H = 16
L = 5
VN = 5
WP = 4
AP = 8
SCW = 11
SCWB = 4
TTW = 16
value_range = [x for x in range(-8, 8) if x != 0]
value_dict = {x: f"{'pos' if x >0 else 'neg'}_{abs(x)}" for x in value_range}
WW = [8] * len(value_range)
def generate_verilog_code(
path_dir,
file_id=0,
H=16,
L=5,
VN=5,
WP=4,
AP=8,
SCW=11,
SCWB=4,
TTW=16,
value_range=[-1, 1],
value_dict={-1: "neg_1", 1: "pos_1"},
WW={-1: 8, 1: 8},
):
file_name = os.path.join(path_dir, f"TFSM_{file_id}.sv")
generator = pl.VerilogGenerator()
module = pl.ModuleBlock(f"TFSM_{file_id}")
# 增加参数
module.add_parameter("H", str(H))
module.add_parameter("L", str(L))
module.add_parameter("VN", str(VN))
module.add_parameter("WP", str(WP))
module.add_parameter("AP", str(AP))
module.add_parameter("SCW", str(SCW))
module.add_parameter("SCWB", str(SCWB))
module.add_parameter("TTW", str(TTW))
for i in value_range:
module.add_parameter(f"WW_{value_dict[i]}", str(WW[i]))
# 增加输入输出
module.add_input("clk")
module.add_input("tree_rstn")
module.add_input("valid")
module.add_input("fsm_rstn")
module.add_input("LM_sel", width="L")
module.add_input("Top_in", width="AP")
module.add_output("WT_result_acc", pl.VAR_TYPE.REG, "TTW", "VN")
module.add_output("result_valid", pl.VAR_TYPE.REG)
module.add_reg("TM_sel", "AP")
module.add_wire("TM_out", "H")
# 实例化 Top_mux
Top_mux_params = {"H": "H", "AP": "AP"}
Top_mux_ports = {"TM_sel": "TM_sel", "TM_in": "TM_in", "TM_out": "TM_out"}
module.add_instance("Top_mux", "top_mux", Top_mux_params, Top_mux_ports)
for i in value_range:
module.add_wire(f"WT_{value_dict[i]}_out_S", height="VN")
module.add_wire(f"WT_{value_dict[i]}_out_C", height="VN")
module.add_reg("tree_rstn")
module.add_reg("mac_rstn")
for i in value_range:
module.add_reg(f"final_S_{value_dict[i]}", "SCW", "VN")
module.add_reg(f"final_C_{value_dict[i]}", "SCW", "VN")
module.add_reg("MAC_in_1", "SCW+2", "VN")
module.add_reg("MAC_in_2", "WP", "VN")
module.add_wire("MAC_out", "TTW+1", "VN")
for i in value_range:
module.add_reg(f"WT_result_{value_dict[i]}", "SCW+2", "VN")
module.add_reg("idx", "SCWB+1")
module.add_reg("state", "3")
module.add_genvar("j")
module.add_integer("i")
with pl.GenerateBlock(module) as generate_block:
with pl.ForBlock(
generate_block, "j=0", "j<VN", "j=j+1", "inst_SW+loop"
) as for_block:
SW_params = {"H": "H", "L": "L"}
for i in value_range:
SW_params[f"WW_{value_dict[i]}"] = f"WW_{value_dict[i]}"
SW_ports = {
"clk": "clk",
"tree_rstn": "tree_rstn",
"valid": "valid",
"LM_sel": "LM_sel",
"SW_in": "TM_out",
}
for i in value_range:
SW_ports[f"WT_out_{value_dict[i]}_S"] = f"WT_{value_dict[i]}_out_S[j]"
SW_ports[f"WT_out_{value_dict[i]}_C"] = f"WT_{value_dict[i]}_out_C[j]"
for_block.add_instance(f"SW_{file_id}", "sub_wrapper", SW_params, SW_ports)
MAC_params = {"W_1": "SCW", "W_2": "WP", "W_0": "TTW"}
MAC_ports = {
"clk": "clk",
"tree_rstn": "tree_rstn",
"MAC_in_1": "MAC_in_1[j]",
"MAC_in_2": "MAC_in_2[j]",
"MAC_out": "MAC_out[j]",
}
for_block.add_instance("MAC", "mac", MAC_params, MAC_ports)
with pl.AlwaysBlock(module, "posedge clk or negedge fsm_rstn") as always_block:
with pl.IfBlock(always_block, "!fsm_rstn") as if_block:
if_block = pl.IfBlock("!fsm_rstn")
if_block.add_body("state <= 0;")
if_block.add_body("idx <= 0;")
if_block.add_body("tree_rstn <= 0;")
if_block.add_body("mac_rstn <= 0;")
if_block.add_body("result_valid <= 0;")
if_block.add_body("TM_sel <= 8'b00000000;")
with pl.ForBlock(if_block, "i=0", "i<VN", "i=i+1") as for_block:
for_block.add_body("MAC_in_1[i] <= 0;")
for_block.add_body("MAC_in_2[i] <= 0;")
with pl.ElseBlock(always_block) as else_block:
with pl.IfBlock(else_block, "state == 0") as if_block:
if_block.add_body("idx <= 0;")
if_block.add_body("tree_rstn <= 0;")
if_block.add_body("result_valid <= 0;")
with pl.ForBlock(if_block, "i=0", "i<VN", "i=i+1") as for_block:
for i in value_range:
for_block.add_body(f"final_S_{value_dict[i]}[i] <= 0;")
for_block.add_body(f"final_C_{value_dict[i]}[i] <= 0;")
for_block.add_body(f"WT_result_{value_dict[i]}[i] <= 0;")
for_block.add_body(f"WT_result_acc[i] <= 0;")
# 写不动了,基本就是一行一行翻译的verilog代码
with pl.IfBlock(if_block, "valid == 1") as if_if_block:
if_if_block.add_body("state <= 1;")
with pl.IfBlock(else_block, "state == 1") as if_block:
with pl.IfBlock(if_block, "valid == 1") as if_if_block:
if_if_block.add_body("tree_rstn <= 1;")
if_if_block.add_body("state <= 2;")
generator.add_module(module)
generator.generate(file_name)
if __name__ == "__main__":
os.makedirs(CFG.path_dir, exist_ok=True)
record_path = os.path.join(CFG.path_dir, "records.txt")
record_f = open(record_path, "w")
record_f.write(f"H={CFG.H}, L={CFG.L}, value_range={CFG.value_range}\n\n")
for i in range(CFG.VN):
generate_verilog_code(
CFG.path_dir,
file_id=i,
H=CFG.H,
L=CFG.L,
value_range=CFG.value_range,
value_dict=CFG.value_dict,
WW=CFG.WW,
)
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pyrilog import (\n",
" VerilogGenerator,\n",
" ModuleBlock,\n",
" GenerateBlock,\n",
" ForBlock,\n",
" add_parameter,\n",
" add_input,\n",
" add_output,\n",
" add_genvar,\n",
" add_assign,\n",
" add_wire,\n",
" add_body,\n",
" add_instance,\n",
" add_newline,\n",
")\n",
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"import os\n",
"from tqdm import tqdm\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class CFG:\n",
" path_dir = \"WT_group\"\n",
" weights_dir = \"../001-H-LLM/weights\"\n",
" num_workers = 16\n",
" value_range = [-6, -4, -3, -2, -1.5, -1, -0.5, 0.5, 1, 1.5, 2, 3, 4, 6]\n",
"\n",
"\n",
"def calculate_WW(matrix: np.array, value_range):\n",
" WW = [0] * len(value_range)\n",
" for i in range(len(value_range)):\n",
" WW[i] = max(\n",
" [len([x for x in row if abs(x - value_range[i]) <= 0.01]) for row in matrix]\n",
" )\n",
" return WW\n",
"\n",
"\n",
"def find_index(arr, target, epsilon=1e-3):\n",
" arr = np.array(arr) # 转换为numpy数组\n",
" diff = np.abs(arr - target) # 计算差值数组\n",
" min_diff = np.min(diff) # 找到最小的差值\n",
" if min_diff < epsilon: # 如果最小差值在允许的误差范围内\n",
" return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引\n",
" raise ValueError(\"No match found\") # 如果没有找到匹配项,则引发异常"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def generate_module(\n",
" matrix,\n",
" module_name_suffix=\"\",\n",
" H=16,\n",
" L=5,\n",
" value_range=[-1, 1],\n",
" WW=[8, 8],\n",
"):\n",
" with ModuleBlock(f\"{CFG.path_dir}_{module_name_suffix}\") as module:\n",
" # 参数)\n",
" for i in range(len(value_range)):\n",
" add_parameter(f\"WW_{i}\", WW[i])\n",
" # 输入输出\n",
" add_input(\"clk\")\n",
" add_input(\"tree_rstn\")\n",
" add_input(\"valid\")\n",
" for i in range(len(value_range)):\n",
" add_input(f\"WT_in_{i}\", f\"WW_{i}\")\n",
" add_output(\n",
" f\"WT_{i}_out_S\",\n",
" )\n",
" add_output(\n",
" f\"WT_{i}_out_C\",\n",
" )\n",
" # 内部连线华莱士树\n",
" for i in range(len(value_range)):\n",
" wallace_name = f\"SerialWallaceTree{WW[i]}Input\"\n",
" wallace_port = {\n",
" \"clk\": \"clk\",\n",
" \"rstn\": \"rstn\",\n",
" \"valid\": \"valid\",\n",
" \"addends\": f\"WT_{i}_in\",\n",
" \"out_S\": \"WT_{i}_out_S\",\n",
" \"out_C\": \"WT_{i}_out_C\",\n",
" }\n",
" add_instance(wallace_name, f\"serial_wallace_tree_{i}\", {}, wallace_port)\n",
"\n",
" return module"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing k\n",
"Loading k\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 6/512 [00:01<01:46, 4.76it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 40\u001b[0m\n\u001b[1;32m 38\u001b[0m VN, L, H \u001b[38;5;241m=\u001b[39m matrixs\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(VN)):\n\u001b[0;32m---> 40\u001b[0m \u001b[43mprocess_task\u001b[49m\u001b[43m(\u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweights_file_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmatrixs\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mH\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;66;03m# with ProcessPoolExecutor(max_workers=CFG.num_workers) as executor:\u001b[39;00m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# futures = [\u001b[39;00m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;66;03m# executor.submit(process_task, i, weights_file_name, matrixs[i], H, L)\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;66;03m# except Exception as e:\u001b[39;00m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;66;03m# print(f\"Generating {result} failed with an error: {e}\")\u001b[39;00m\n",
"Cell \u001b[0;32mIn[8], line 3\u001b[0m, in \u001b[0;36mprocess_task\u001b[0;34m(i, weights_file_name, matrix, H, L)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprocess_task\u001b[39m(i, weights_file_name, matrix, H, L):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m----> 3\u001b[0m WW \u001b[38;5;241m=\u001b[39m \u001b[43mcalculate_WW\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmatrix\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCFG\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue_range\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m file_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(CFG\u001b[38;5;241m.\u001b[39mpath_dir, weights_file_name)\n\u001b[1;32m 5\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(file_dir, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"Cell \u001b[0;32mIn[6], line 12\u001b[0m, in \u001b[0;36mcalculate_WW\u001b[0;34m(matrix, value_range)\u001b[0m\n\u001b[1;32m 9\u001b[0m WW \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mlen\u001b[39m(value_range)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(value_range)):\n\u001b[1;32m 11\u001b[0m WW[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\n\u001b[0;32m---> 12\u001b[0m [\u001b[38;5;28mlen\u001b[39m([x \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m row \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mabs\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mvalue_range\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.01\u001b[39m]) \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m matrix]\n\u001b[1;32m 13\u001b[0m )\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m WW\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"def process_task(i, weights_file_name, matrix, H, L):\n",
" try:\n",
" WW = calculate_WW(matrix, CFG.value_range)\n",
" file_dir = os.path.join(CFG.path_dir, weights_file_name)\n",
" os.makedirs(file_dir, exist_ok=True)\n",
" file_name = os.path.join(\n",
" file_dir, f\"{CFG.path_dir}_tp_{weights_file_name}_vc_{i}.sv\"\n",
" )\n",
" with open(file_name, \"w\") as f:\n",
" f.write(\n",
" generate_module(\n",
" matrix,\n",
" module_name_suffix=f\"_tp_{weights_file_name}_vc_{i}\",\n",
" H=H,\n",
" L=L,\n",
" value_range=CFG.value_range,\n",
" WW=WW,\n",
" ).generate()\n",
" )\n",
" return i # 返回任务ID以显示进度\n",
" except Exception as e:\n",
" print(f\"Generating {i} failed with an error: {e}\")\n",
" return None\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" os.makedirs(CFG.path_dir, exist_ok=True)\n",
" for weights_file in os.listdir(CFG.weights_dir):\n",
" if weights_file != \"k.pkl\":\n",
" continue\n",
" weights_path = os.path.join(CFG.weights_dir, weights_file)\n",
" weights_file_name = os.path.splitext(weights_file)[0]\n",
" print(f\"Processing {weights_file_name}\")\n",
" with open(weights_path, \"rb\") as f:\n",
" print(f\"Loading {weights_file_name}\")\n",
" matrixs = pickle.load(f)\n",
" matrixs = np.transpose(matrixs, (1, 0, 2))\n",
" VN, L, H = matrixs.shape\n",
" for i in tqdm(range(VN)):\n",
" process_task(i, weights_file_name, matrixs[i], H, L)\n",
" # with ProcessPoolExecutor(max_workers=CFG.num_workers) as executor:\n",
" # futures = [\n",
" # executor.submit(process_task, i, weights_file_name, matrixs[i], H, L)\n",
" # for i in range(VN)\n",
" # ]\n",
" # for future in tqdm(as_completed(futures), total=VN):\n",
" # try:\n",
" # result = future.result()\n",
" # except Exception as e:\n",
" # print(f\"Generating {result} failed with an error: {e}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import numpy as np
import numpy as np
import pickle
import matplotlib.pyplot as plt
length = 3840
file_path = f"weights-fp32-{length}.pkl"
# 载入权重矩阵
with open(file_path, "rb") as f:
weights = pickle.load(f)
# 检验weights的维度
print(f"weight matrix shape: {weights.shape}")
# # new_shape = (2304, 288) # 576组,每组1152个元素
# # weights = weights.reshape(new_shape)
# 确定缩放因子以使用int4范围(-8到7)
max_value = np.max(weights)
min_value = np.min(weights)
# 归一化权重到int4范围
normalized_weights = (weights - min_value) / (max_value - min_value) * 15 - 8
normalized_weights = np.round(normalized_weights)
# 限制值确保其在int4范围内
quantized_weights = np.clip(normalized_weights, -8, 7).T
# 计算每组中各int4值的频数
int4_values = np.arange(-8, 8)
frequency_counts = np.zeros((length, len(int4_values)))
for i in range(length):
frequency_counts[i, :] = np.histogram(
quantized_weights[i], bins=np.arange(-8.5, 8.5)
)[0]
# 计算每个int4取值在所有组中的标准差和极差
std_devs_per_value = np.std(frequency_counts, axis=0)
ranges_per_value = np.ptp(frequency_counts, axis=0)
# 打印结果
print("Standard deviations per int4 value:", std_devs_per_value)
print("Ranges per int4 value:", ranges_per_value)
# 计算每个int4取值在所有组中的最大频次
max_frequencies_per_value = np.max(frequency_counts, axis=0)
# 计算所有取值的最大频次的总和
total_max_frequency_sum = np.sum(max_frequencies_per_value)
# 打印每个取值的最大频次和总和
print("Maximum frequency per int4 value:", max_frequencies_per_value)
print("Sum of maximum frequencies:", total_max_frequency_sum)
# 绘制标准差和极差的图
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.bar(int4_values, std_devs_per_value, color="blue")
plt.title("Standard Deviation of Frequency per int4 Value")
plt.xlabel("int4 Value")
plt.ylabel("Standard Deviation")
plt.subplot(1, 2, 2)
plt.bar(int4_values, ranges_per_value, color="red")
plt.title("Range of Frequency per int4 Value")
plt.xlabel("int4 Value")
plt.ylabel("Range")
plt.tight_layout()
plt.show()
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'utils_quant'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtqdm\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m tqdm\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mutils_quant\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpickle\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils_quant'"
]
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import safetensors\n",
"import torch\n",
"from tqdm import tqdm\n",
"import utils_quant\n",
"import pickle\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tensors_1536=[]\n",
"tensors_3840=[]\n",
"file_path=\"/lustre/S/huangdi/open_for_out/models/aimo-progress-prize-trained-models/Code-Math-QA-Proof-quant-per-head-fp4-0913/model.safetensors\"\n",
"with safetensors.safe_open(file_path,framework=\"pt\") as f:\n",
" for i,key in enumerate(tqdm(f.keys())):\n",
" # print(key,f.get_tensor(key).shape)\n",
" if i>10:\n",
" break\n",
" tensor=f.get_tensor(key)\n",
" if tensor.ndim==2:\n",
" if len(tensor[0])==1536:\n",
" tensors_1536.extend(tensor.float().tolist())\n",
" else:\n",
" tensors_3840.extend(tensor.float().tolist())\n",
" else:\n",
" if len(tensor)==1536:\n",
" tensors_1536.append(tensor.float().tolist())\n",
" else:\n",
" tensors_3840.append(tensor.float().tolist())\n",
"tensors_1536=np.array(tensors_1536)\n",
"tensors_3840=np.array(tensors_3840)\n",
"# tensors=np.array(tensors,dtype=np.float32)\n",
"# display(tensors_fp32[:5])\n",
"with open(\"weights-fp32-1536-small.pkl\",\"wb\") as f:\n",
" pickle.dump(tensors_1536, f)\n",
"with open(\"weights-fp32-3840-small.pkl\",\"wb\") as f:\n",
" pickle.dump(tensors_3840, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from prettytable import PrettyTable\n",
"from config import CFG\n",
"import os\n",
"from multiprocessing import Pool"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def calculate_WW(matrix: np.array, value_range):\n",
" WW = [0] * len(value_range)\n",
" for i in range(len(value_range)):\n",
" WW[i] = max(\n",
" [len([x for x in row if abs(x - value_range[i]) <= 0.01]) for row in matrix]\n",
" )\n",
" return WW\n",
"\n",
"def find_index(arr, target, epsilon=1e-3):\n",
" arr = np.array(arr) # 转换为numpy数组\n",
" diff = np.abs(arr - target) # 计算差值数组\n",
" min_diff = np.min(diff) # 找到最小的差值\n",
" if min_diff < epsilon: # 如果最小差值在允许的误差范围内\n",
" return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引\n",
" raise ValueError(\"No match found\") # 如果没有找到匹配项,则引发异常\n",
"\n",
"#返回第i位\n",
"def get_bit(num, i):\n",
" if i < 0:\n",
" return 0\n",
" return (num >> i) & 1"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class HN:\n",
" def __init__(self, matrix,H,L):\n",
" self.matrix = matrix\n",
" self.H=H\n",
" self.L=L\n",
" \n",
" # def find_index(self, value):\n",
" # return np.searchsorted(CFG.value_range, value)\n",
" \n",
" def calculate(self,HN_in:np.ndarray):\n",
" HN_out = np.zeros((self.L, len(CFG.value_range)), dtype=int)\n",
" ans=np.zeros(self.L)\n",
" matrix_masked = self.matrix * HN_in\n",
" for i, layer in enumerate(matrix_masked):\n",
" for j, value in enumerate(layer):\n",
" if abs(value)<=1e-3:\n",
" continue\n",
" index=find_index(CFG.value_range,value)\n",
" HN_out[i][index]+=1\n",
" ans[i]+=value\n",
" # indices=list(map(self.find_index,layer))\n",
" # np.add.at(HN_out[i],indices,1)\n",
" return HN_out,ans\n",
" \n",
"class HN_GROUP:\n",
" def __init__(self,weights:np.ndarray):\n",
" self.VN,self.L,self.H=weights.shape\n",
" print(weights.shape)\n",
" self.HN_GROUP=[HN(matrix,self.H,self.L) for matrix in weights]\n",
" print(\"HN_GROUP init done\")\n",
" \n",
" \n",
" def calculate_single(args):\n",
" hn, hn_in = args\n",
" return hn.calculate(hn_in)\n",
" \n",
" def calculate(self,hn_in:np.ndarray,layer:int):\n",
" hn_out=[None]*self.VN\n",
" ans=[None]*self.VN\n",
" with Pool() as pool:\n",
" args = [(self.HN_GROUP[i], hn_in) for i in range(self.VN)]\n",
" results = list(tqdm(pool.imap(self.calculate_single, args), total=self.VN))\n",
" for i, result in enumerate(results):\n",
" hn_out[i], ans[i] = result\n",
" ans = [x[layer] for x in ans]\n",
" return ans\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(512, 52, 1536)\n",
"HN_GROUP init done\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/512 [00:00<?, ?it/s]"
]
}
],
"source": [
"if CFG.mode == \"test\":\n",
" weights_name = CFG.test_weights\n",
"elif CFG.mode == \"run\":\n",
" weights_name = CFG.run_weigths\n",
"else:\n",
" raise ValueError(\"Invalid mode\")\n",
"activation_name=os.path.join(CFG.activation_dir,\"activation.pkl\")\n",
"result_name=os.path.join(CFG.results_dir,\"result.txt\")\n",
"with open(activation_name, \"rb\") as f:\n",
" hn_in=pickle.load(f)\n",
"weights_path = os.path.join(CFG.weights_dir, weights_name)\n",
"with open(weights_path, \"rb\") as f:\n",
" matrixs = pickle.load(f)\n",
" matrixs = np.transpose(matrixs, (1, 0, 2))\n",
"\n",
"hn_group=HN_GROUP(matrixs)\n",
"\n",
"hn_in=get_bit(hn_in,7)\n",
"\n",
"hn_group.calculate(hn_in,0)\n",
"# for matrix in matrixs:\n",
"# table=PrettyTable()\n",
"# table.field_names=[str(f) for f in CFG.value_range]\n",
"# print(\"--------------------------------\")\n",
"# print(\"输入\")\n",
"# print(hn_in)\n",
"# print(\"权重\")\n",
"# print(matrix)\n",
"# print(\"结果-单独\")\n",
"# hn_out,ans=hn.calculate(hn_in)\n",
"# table.add_rows(hn_out)\n",
"# print(table)\n",
"# print(\"结果-总和\")\n",
"# print(ans)\n",
"# print()\n",
"# break\n",
" \n",
" \n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import os
import sys
import pickle
def update_weights_shape(weights):
weights_file = os.path.join(CFG.weights_dir, weights)
with open(weights_file) as f:
weights = pickle.load(f)
shape = weights.shape
return shape
class CFG:
def __init__(self):
self.mode = "run" # "test" or "run"
self.run_weights_batch = [
"down.pkl",
"up.pkl",
"gate.pkl",
"k.pkl",
"o.pkl",
"v.pkl",
"q.pkl",
]
self.run_weights = "down.pkl" # 用于赋值
self.safetensors = "model.safetensors"
self.weights_dir = "001-H-LLM/qwen"
self.mapped_weights_dir = "001-H-LLM/qwen/mapped_weights"
self.verify_generate_activation_on_exist = False
self.num_workers = 64
self.value_range = [-6, -4, -3, -2, -1.5, -1, -0.5, 0.5, 1, 1.5, 2, 3, 4, 6]
self.python_path = sys.executable
self.group_number = 32
self.output_dir = "outputs"
os.makedirs(self.weights_dir, exist_ok=True)
os.makedirs(self.output_dir, exist_ok=True)
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pickle\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from prettytable import PrettyTable\n",
"from copy import deepcopy"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def eda_weights(weights):\n",
" fp4_values = np.array([-6, -4, -3, -2, -1.5, -1, -0.5, 0.5, 1, 1.5, 2, 3, 4, 6])\n",
" weights=weights.reshape(weights.size//weights.shape[-1],weights.shape[-1])\n",
" frequency_counts = np.zeros((len(weights), len(fp4_values)))\n",
" bins = np.concatenate((fp4_values - 0.01, [fp4_values[-1] + 0.01]))\n",
" for i in tqdm(range(len(weights))):\n",
" frequency_counts[i]=np.histogram(weights[i],bins=bins)[0]\n",
" # 计算每个int4取值在所有组中的标准差和极差\n",
" # display(frequency_counts)\n",
" # std_devs_per_value = np.std(frequency_counts, axis=0)\n",
" # ranges_per_value = np.ptp(frequency_counts, axis=0)\n",
" # 打印结果\n",
"\n",
" # print(\"Standard deviations per int4 value:\", std_devs_per_value)\n",
" # print(\"Ranges per int4 value:\", ranges_per_value)\n",
" # 计算每个int4取值在所有组中的最大频次\n",
" max_frequencies_per_value = np.max(frequency_counts, axis=0)\n",
" # 计算所有取值的最大频次的总和\n",
" total_max_frequency_sum = np.sum(max_frequencies_per_value)\n",
" # 打印每个取值的最大频次和总和\n",
" # print(\"Maximum frequency per int4 value:\", max_frequencies_per_value)\n",
" print(\"Sum of maximum frequencies:\", total_max_frequency_sum)\n",
" \n",
" table=PrettyTable()\n",
" table.field_names=[\"type\"]+[str(f) for f in fp4_values]\n",
" # table.add_row([\"std\"]+[str(round(f,3)) for f in std_devs_per_value])\n",
" # table.add_row([\"range\"]+[str(round(f,3)) for f in ranges_per_value])\n",
" table.add_row([\"max\"]+[str(round(f,3)) for f in max_frequencies_per_value])\n",
" print(table)\n",
" # 绘制标准差和极差的图\n",
" # plt.figure(figsize=(12, 6))\n",
" # plt.subplot(1, 2, 1)\n",
" # plt.bar(fp4_values, std_devs_per_value, color=\"blue\")\n",
" # plt.title(\"Standard Deviation of Frequency per int4 Value\")\n",
" # plt.xlabel(\"int4 Value\")\n",
" # plt.ylabel(\"Standard Deviation\")\n",
"\n",
" # plt.subplot(1, 2, 2)\n",
" # plt.bar(fp4_values, ranges_per_value, color=\"red\")\n",
" # plt.title(\"Range of Frequency per int4 Value\")\n",
" # plt.xlabel(\"int4 Value\")\n",
" # plt.ylabel(\"Range\")\n",
"\n",
" # plt.tight_layout()\n",
" # plt.show()\n",
"\n",
"def eda_weights_52(weights):\n",
" fp4_values = np.array([-6, -4, -3, -2, -1.5, -1, -0.5,0.5, 1, 1.5, 2, 3, 4, 6])\n",
" weights=np.transpose(weights,(1,0,2))#512,52,1536\n",
" frequency_counts = np.zeros((len(weights), 52,len(fp4_values)))\n",
" bins = np.concatenate((fp4_values - 0.01, [fp4_values[-1] + 0.01]))\n",
" for i in tqdm(range(len(weights))):\n",
" for j in range(52):\n",
" frequency_counts[i,j]=np.histogram(weights[i,j],bins=bins)[0]\n",
" # 计算每个int4取值在所有组中的标准差和极差 512,14\n",
" frequency_counts=np.max(frequency_counts,axis=1)\n",
" # std_devs_per_value = np.std(frequency_counts, axis=0)\n",
" # ranges_per_value = np.ptp(frequency_counts, axis=0)\n",
" # 打印结果\n",
"\n",
" # print(\"Standard deviations per int4 value:\", std_devs_per_value)\n",
" # print(\"Ranges per int4 value:\", ranges_per_value)\n",
" # 计算每个int4取值在所有组中的最大频次\n",
" mean_frequencies_per_value = np.mean(frequency_counts, axis=0)\n",
" # 计算所有取值的最大频次的总和\n",
" total_mean_frequency_sum = np.sum(mean_frequencies_per_value)\n",
" # 打印每个取值的最大频次和总和\n",
" # print(\"Maximum frequency per int4 value:\", max_frequencies_per_value)\n",
" print(\"Sum of mean frequencies:\", total_mean_frequency_sum)\n",
" \n",
" table=PrettyTable()\n",
"\n",
" table.field_names=[\"type\"]+[str(f) for f in fp4_values]\n",
" # table.add_row([\"std\"]+[str(round(f,3)) for f in std_devs_per_value])\n",
" # table.add_row([\"range\"]+[str(round(f,3)) for f in ranges_per_value])\n",
" table.add_row([\"mean\"]+[str(round(f,3)) for f in mean_frequencies_per_value])\n",
" print(table)\n",
" # 绘制标准差和极差的图\n",
" # plt.figure(figsize=(12, 6))\n",
" # plt.subplot(1, 2, 1)\n",
" # plt.bar(fp4_values, std_devs_per_value, color=\"blue\")\n",
" # plt.title(\"Standard Deviation of Frequency per int4 Value\")\n",
" # plt.xlabel(\"int4 Value\")\n",
" # plt.ylabel(\"Standard Deviation\")\n",
"\n",
" # plt.subplot(1, 2, 2)\n",
" # plt.bar(fp4_values, ranges_per_value, color=\"red\")\n",
" # plt.title(\"Range of Frequency per int4 Value\")\n",
" # plt.xlabel(\"int4 Value\")\n",
" # plt.ylabel(\"Range\")\n",
"\n",
" # plt.tight_layout()\n",
" # plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = Path(\"weights\")\n",
"for file_path in path.rglob('*'):\n",
" if file_path.is_file() and \"proj.pkl\"in str(file_path):\n",
" with open(file_path, \"rb\") as f :\n",
" print(f\"Reading file {file_path}\")\n",
" weights = pickle.load(f)\n",
" print(f\"weight matrix shape: {weights.shape}\")\n",
" eda_weights(deepcopy(weights))\n",
" eda_weights_52(deepcopy(weights))\n",
" print()\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import safetensors\n",
"import torch\n",
"from tqdm import tqdm\n",
"from utils_quant import quant_and_dequant\n",
"import pickle\n",
"from hllm.config import CFG\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 837/837 [04:13<00:00, 3.30it/s]\n",
"100%|██████████| 11/11 [02:13<00:00, 12.17s/it]\n"
]
}
],
"source": [
"weights = {\n",
" \"input_layernorm\": [],\n",
" \"down_proj\": [],\n",
" \"gate_proj\": [],\n",
" \"up_proj\": [],\n",
" \"post_attention_layernorm\": [],\n",
" \"k_proj\": [],\n",
" \"o_proj\": [],\n",
" \"q_proj\": [],\n",
" \"v_proj\": [],\n",
" \"embed_tokens\": [],\n",
" \"model.norm\": [],\n",
"}\n",
"\n",
"\n",
"# 列表中是否有字符串的子串\n",
"def is_substring_in_list(substring, string_list):\n",
" return any(s in substring for s in string_list)\n",
"\n",
"\n",
"ignored_weights = [\n",
" \"embed_tokens.weight\",\n",
" \"post_attention_layernorm.weight\",\n",
" \"activation_quant\",\n",
" \"input_layernorm.weight\",\n",
"]\n",
"\n",
"file_path = \"../001-H-LLM/weights1026/model.safetensors\"\n",
"with safetensors.safe_open(file_path, framework=\"pt\") as f:\n",
" for i, key in enumerate(tqdm(f.keys())):\n",
" if is_substring_in_list(key, ignored_weights):\n",
" continue\n",
" tensor = f.get_tensor(key)\n",
" # print(key,tensor.shape)\n",
" # if i>10:\n",
" # break\n",
" tensor = quant_and_dequant(tensor, 4).tolist()\n",
" for k in weights.keys():\n",
" if k in key:\n",
" weights[k].append(tensor)\n",
"\n",
"for key in tqdm(weights.keys()):\n",
" weights[key] = np.array(weights[key])\n",
" file_path = f\"../001-H-LLM/weights1026/{key}.pkl\"\n",
" with open(file_path, \"wb\") as f:\n",
" pickle.dump(weights[key], f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import os
import numpy as np
import safetensors
import pickle
from tqdm import tqdm
from hllm.eda.utils_quant import quant_and_dequant
from hllm.config import CFG
# %%
name_dict = {
"down_proj.weight": "down",
"gate_proj.weight": "gate",
"up_proj.weight": "up",
"k_proj.weight": "k",
"o_proj.weight": "o",
"q_proj.weight": "q",
"v_proj.weight": "v",
"model.norm.weight": "norm",
}
weights = {v: [] for v in name_dict.values()}
# 列表中是否有字符串的子串
def is_substring_in_list(substring, string_list):
return any(s in substring for s in string_list)
def run(config: CFG):
file_path = os.path.join(config.weights_dir, config.safetensors)
with safetensors.safe_open(file_path, framework="pt") as f:
for i, key in enumerate(tqdm(f.keys())):
if not is_substring_in_list(key, name_dict.keys()):
continue
tensor = f.get_tensor(key)
tensor = quant_and_dequant(tensor, 4)
for k in name_dict.keys():
if k in key:
weights[name_dict[k]].append(tensor.tolist())
for k in weights.keys():
weights[k] = np.array(weights[k])
file_path = os.path.join(config.weights_dir, f"{k}.pkl")
with open(file_path, "wb") as f:
pickle.dump(weights[k], f)
import numpy as np
import os
from hllm.config import CFG
import pickle
from tqdm import tqdm
def mapping_weights(weights, value_range):
new_weights = np.full_like(weights, -1, dtype=int)
for i in range(len(value_range)):
new_weights[abs(weights - value_range[i]) <= 0.01] = i
return new_weights
def run(config: CFG):
print("Start mapping weights")
path_dir = os.path.join(config.mapped_weights_dir)
os.makedirs(path_dir, exist_ok=True)
value_range = config.value_range
for file in os.listdir(config.weights_dir):
if file in config.run_weights_batch:
with open(os.path.join(config.weights_dir, file), "rb") as f:
print(f"Loading {file}")
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
print(VN, L, H)
new_weights = mapping_weights(matrixs, value_range)
new_weights = np.transpose(new_weights, (1, 0, 2))
with open(os.path.join(path_dir, file), "wb") as f:
pickle.dump(new_weights, f)
print("Mapped weights at", path_dir)
import math
import torch
from torch import nn
def weight_quant(weight, num_bits=1):
dtype = weight.dtype
weight = weight.float()
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
s = Qp / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(Qn, Qp) / s
return result.type(dtype)
def activation_quant(x, num_bits=8):
dtype = x.dtype
x = x.float()
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp) / s
return result.type(dtype)
def get_scale_f32(src_amax, dst_max):
S = (src_amax.float()) / dst_max
qscale = 1 / S
dqscale = S
return qscale, dqscale
def round_to_FP4(input):
dst_max = 6.0
emax = 2
emin = 0
p = 2
part = 2 - 2 ** (1 - p)
ab = torch.where(
torch.isinf(input) + torch.isnan(input), torch.ones_like(input) * dst_max, input
)
ab = torch.where(ab > dst_max, torch.ones_like(ab) * dst_max, ab)
ab = torch.where(ab < 2.0 ** (emin) * 2 ** (-p), torch.zeros_like(ab), ab)
E = torch.where(
ab < 2 ** (emin),
torch.ones_like(ab) * (emin),
torch.floor(torch.log2(ab.float())),
)
P = torch.round(ab * 2 ** (-E) * 2 ** (p - 1)) / 2 ** (p - 1)
data = 2**E * P
return data
def quant_and_dequant(data, num_bits):
sign = torch.sign(data)
abs_data = torch.abs(data).float()
amax, index = torch.max(
abs_data, -1, True
) # 这个示例是做的per-channel量化,即对于(M,K)的矩阵,有M个量化参数(M个amax)
qscale, dqscale = get_scale_f32(amax, 6.0)
quant_data = round_to_FP4(abs_data * qscale)
dequant_data = (quant_data * dqscale * sign).to(data.dtype)
return sign * quant_data
return dequant_data
class CLMLinear(nn.Linear):
def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
super(CLMLinear, self).__init__(*kargs, **kwargs)
"""
RMSNorm is placed outside BitLinear
"""
self.weight_bits = weight_bits
self.input_bits = input_bits
def forward(self, input):
quant_input = (
input + (activation_quant(input, self.input_bits) - input).detach()
)
quant_weight = (
self.weight
+ (quant_and_dequant(self.weight, self.weight_bits) - self.weight).detach()
)
out = nn.functional.linear(quant_input, quant_weight)
if not self.bias is None:
out += self.bias.view(1, -1).expand_as(out)
return out
import os
from hllm.config import CFG
class TCL_dependency:
def __init__(self, config: CFG, name: str, file_name: str, weights_file_name: str,use_weights: bool = True):
self.config = config
self.name = name
self.file_name = file_name
self.weights_file_name = weights_file_name
self.use_weights = use_weights
def __str__(self):
if self.use_weights:
path = os.path.join(
self.config.output_dir,
self.name,
self.weights_file_name,
self.file_name,
)
else:
path = os.path.join(
self.config.output_dir,
self.name,
self.file_name,
)
path = os.path.abspath(path)
return f"{path}\n"
class TCL:
def __init__(self, config: CFG, weights_file_name: str):
self.config = config
self.dependencies = []
self.vlist = ""
self.weights_file_name = weights_file_name
def add_dependency(self, name: str, file_name: str,use_weights: bool = True):
self.dependencies.append(TCL_dependency(self.config, name, file_name, self.weights_file_name,use_weights))
def set_vlist(self, vlist: str):
self.vlist = vlist
def generate(self):
tcl = ""
for dependency in self.dependencies:
tcl += f"{dependency}\n"
unique_lines = sorted(set(tcl.strip().split("\n")))
result = "\n".join(unique_lines).strip()
result = f'set {self.vlist} "\n' + result + '\n"'
return result
mapped_weights
output
Optimized_HN
Optimized_HN_mux
WT_group
Optimized_mux
Mux_wrapper
Mux
build
optimize_HN.egg-info
Sub_wrapper
dist
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from tqdm import tqdm
import pickle
from hllm.config import CFG
# %%
def generate_module(
cur_GP=0,
module_name="",
H=16,
L=5,
VN=512,
value_range=[-1, 1],
weights_file_name=None,
WW=[8, 8],
config: CFG = None,
):
module_name_suffix = f"_tp_{weights_file_name}_gp_{cur_GP}"
str = f"""
module FSM_tp_{weights_file_name}_gp_{cur_GP} #(
parameter H = 896, // hidden layer dim, 896
parameter L = 24, // layer num, 24
parameter VN = 32, // vector num, 896 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 18, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 32 // MAC output total width, SCW + WP + 4
) (
input clk,
input valid,
input fsm_rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
output reg [TTW - 1 : 0] WT_result_acc[VN - 1 : 0],
output reg result_valid
);"""
str += r"""
reg [AP - 1 : 0] TM_sel;
wire [ H - 1 : 0] TM_out;
Top_mux #(
.H (H),
.AP(AP)
) top_mux (
.TM_sel(TM_sel),
.TM_in (Top_in),
.TM_out(TM_out)
);
wire WT_v0_out_S[VN - 1 : 0];
wire WT_v0_out_C[VN - 1 : 0];
wire WT_v1_out_S[VN - 1 : 0];
wire WT_v1_out_C[VN - 1 : 0];
wire WT_v2_out_S[VN - 1 : 0];
wire WT_v2_out_C[VN - 1 : 0];
wire WT_v3_out_S[VN - 1 : 0];
wire WT_v3_out_C[VN - 1 : 0];
wire WT_v4_out_S[VN - 1 : 0];
wire WT_v4_out_C[VN - 1 : 0];
wire WT_v5_out_S[VN - 1 : 0];
wire WT_v5_out_C[VN - 1 : 0];
wire WT_v6_out_S[VN - 1 : 0];
wire WT_v6_out_C[VN - 1 : 0];
wire WT_v7_out_S[VN - 1 : 0];
wire WT_v7_out_C[VN - 1 : 0];
wire WT_v8_out_S[VN - 1 : 0];
wire WT_v8_out_C[VN - 1 : 0];
wire WT_v9_out_S[VN - 1 : 0];
wire WT_v9_out_C[VN - 1 : 0];
wire WT_v10_out_S[VN - 1 : 0];
wire WT_v10_out_C[VN - 1 : 0];
wire WT_v11_out_S[VN - 1 : 0];
wire WT_v11_out_C[VN - 1 : 0];
wire WT_v12_out_S[VN - 1 : 0];
wire WT_v12_out_C[VN - 1 : 0];
wire WT_v13_out_S[VN - 1 : 0];
wire WT_v13_out_C[VN - 1 : 0];
reg tree_valid; // carry reg clear if not valid
reg [SCW - 1 : 0] final_S_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v13[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v13[VN - 1 : 0];
wire [TTW - 1 : 0] MAC_out[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v0[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v1[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v2[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v3[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v4[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v5[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v6[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v7[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v8[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v9[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v10[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v11[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v12[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v13[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v0_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v1_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v2_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v3_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v4_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v5_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v6_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v7_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v8_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v9_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v10_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v11_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v12_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v13_wire[VN - 1 : 0];
reg [7 : 0] state_idx; // MSB 3: state; LSB 5: idx
reg [7 : 0] next_state_idx;
wire [2 : 0] state;
wire [4 : 0] idx;
assign state = state_idx[7 : 5];
assign idx = state_idx[4 : 0];
reg CST_LOW;
genvar j;
integer i;"""
str += f"""
Mid_wrapper_tp_{weights_file_name}_gp_{cur_GP} #(
.H (H),
.L (L),
.VN(VN)
) mid_wrappers (
.clk(clk),
.tree_rstn(fsm_rstn),
.valid(tree_valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(TM_out),
.WT_0_out_S(WT_v0_out_S),
.WT_0_out_C(WT_v0_out_C),
.WT_1_out_S(WT_v1_out_S),
.WT_1_out_C(WT_v1_out_C),
.WT_2_out_S(WT_v2_out_S),
.WT_2_out_C(WT_v2_out_C),
.WT_3_out_S(WT_v3_out_S),
.WT_3_out_C(WT_v3_out_C),
.WT_4_out_S(WT_v4_out_S),
.WT_4_out_C(WT_v4_out_C),
.WT_5_out_S(WT_v5_out_S),
.WT_5_out_C(WT_v5_out_C),
.WT_6_out_S(WT_v6_out_S),
.WT_6_out_C(WT_v6_out_C),
.WT_7_out_S(WT_v7_out_S),
.WT_7_out_C(WT_v7_out_C),
.WT_8_out_S(WT_v8_out_S),
.WT_8_out_C(WT_v8_out_C),
.WT_9_out_S(WT_v9_out_S),
.WT_9_out_C(WT_v9_out_C),
.WT_10_out_S(WT_v10_out_S),
.WT_10_out_C(WT_v10_out_C),
.WT_11_out_S(WT_v11_out_S),
.WT_11_out_C(WT_v11_out_C),
.WT_12_out_S(WT_v12_out_S),
.WT_12_out_C(WT_v12_out_C),
.WT_13_out_S(WT_v13_out_S),
.WT_13_out_C(WT_v13_out_C)
);
"""
str += r"""
generate
for (j = 0; j < VN; j = j + 1) begin : inst_SW_loop
MAC #(
.W_1(SCW), // input 1 width
.W_2(WP), // input 2 width
.W_O(TTW), // output width
.NUM(16) // parallel width
) mac (
.clk(clk),
.rstn(fsm_rstn),
.MAC_in_1({
{SCW{1'b0}},
{SCW{1'b0}},
WT_result_v13[j],
WT_result_v12[j],
WT_result_v11[j],
WT_result_v10[j],
WT_result_v9[j],
WT_result_v8[j],
WT_result_v7[j],
WT_result_v6[j],
WT_result_v5[j],
WT_result_v4[j],
WT_result_v3[j],
WT_result_v2[j],
WT_result_v1[j],
WT_result_v0[j]
}),
//.MAC_in_2({weight_0, weight_1, -5'd6, -5'd4, -5'd3, -5'd2, -5'd1, 5'd0, 5'd1, 5'd2, 5'd3, 5'd4, 5'd6, 5'd8, 5'd12, 5'd0}),
.MAC_out(MAC_out[j])
);
end
endgenerate
genvar k;
generate
for (k = 0; k < VN; k = k + 1) begin
assign WT_result_v0_wire[k] = {1'b0, final_S_v0[k]} + {final_C_v0[k], 1'b0};
assign WT_result_v1_wire[k] = {1'b0, final_S_v1[k]} + {final_C_v1[k], 1'b0};
assign WT_result_v2_wire[k] = {1'b0, final_S_v2[k]} + {final_C_v2[k], 1'b0};
assign WT_result_v3_wire[k] = {1'b0, final_S_v3[k]} + {final_C_v3[k], 1'b0};
assign WT_result_v4_wire[k] = {1'b0, final_S_v4[k]} + {final_C_v4[k], 1'b0};
assign WT_result_v5_wire[k] = {1'b0, final_S_v5[k]} + {final_C_v5[k], 1'b0};
assign WT_result_v6_wire[k] = {1'b0, final_S_v6[k]} + {final_C_v6[k], 1'b0};
assign WT_result_v7_wire[k] = {1'b0, final_S_v7[k]} + {final_C_v7[k], 1'b0};
assign WT_result_v8_wire[k] = {1'b0, final_S_v8[k]} + {final_C_v8[k], 1'b0};
assign WT_result_v9_wire[k] = {1'b0, final_S_v9[k]} + {final_C_v9[k], 1'b0};
assign WT_result_v10_wire[k] = {1'b0, final_S_v10[k]} + {final_C_v10[k], 1'b0};
assign WT_result_v11_wire[k] = {1'b0, final_S_v11[k]} + {final_C_v11[k], 1'b0};
assign WT_result_v12_wire[k] = {1'b0, final_S_v12[k]} + {final_C_v12[k], 1'b0};
assign WT_result_v13_wire[k] = {1'b0, final_S_v13[k]} + {final_C_v13[k], 1'b0};
end
endgenerate
// fsm next state generation
always @(state_idx, valid) begin
case(state)
3'b000: begin
if (valid == 1) next_state_idx = 8'b00100000;
else next_state_idx = 0;
end
3'b001: begin
if (idx == SCW) next_state_idx = 8'b01000000;
else next_state_idx = state_idx + 1;
end
3'b010: next_state_idx = 8'b01100000;
3'b011: next_state_idx = 8'b10000000;
3'b100: next_state_idx = 0;
default: next_state_idx = 0;
endcase
end
// fsm state transfer
always @(posedge clk) begin
if(!fsm_rstn)
state_idx <= 0;
else
state_idx <= next_state_idx;
end
//fsm output
always @(posedge clk or negedge fsm_rstn) begin
if (!fsm_rstn) begin
//state <= 0;
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
CST_LOW <= 0;
TM_sel <= 8'b00000000;
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= 0;
end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
end else begin
if (state == 0) begin
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
/*
if (valid == 1) begin
state <= 1;
end
else begin
state <= 0;
end
*/
end
else if (state == 1) begin
tree_valid <= 1;
if (idx == 0) begin
TM_sel <= 8'b00000001;
end else begin
case (idx)
1: TM_sel <= 8'b00000010;
2: TM_sel <= 8'b00000100;
3: TM_sel <= 8'b00001000;
4: TM_sel <= 8'b00010000;
5: TM_sel <= 8'b00100000;
6: TM_sel <= 8'b01000000;
7: TM_sel <= 8'b10000000;
default: TM_sel <= 8'b10000000; // signed extension
endcase
//if (idx == SCW) begin
// state <= 2;
//end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i][idx-1] <= WT_v0_out_S[i];
final_C_v0[i][idx-1] <= WT_v0_out_C[i];
final_S_v1[i][idx-1] <= WT_v1_out_S[i];
final_C_v1[i][idx-1] <= WT_v1_out_C[i];
final_S_v2[i][idx-1] <= WT_v2_out_S[i];
final_C_v2[i][idx-1] <= WT_v2_out_C[i];
final_S_v3[i][idx-1] <= WT_v3_out_S[i];
final_C_v3[i][idx-1] <= WT_v3_out_C[i];
final_S_v4[i][idx-1] <= WT_v4_out_S[i];
final_C_v4[i][idx-1] <= WT_v4_out_C[i];
final_S_v5[i][idx-1] <= WT_v5_out_S[i];
final_C_v5[i][idx-1] <= WT_v5_out_C[i];
final_S_v6[i][idx-1] <= WT_v6_out_S[i];
final_C_v6[i][idx-1] <= WT_v6_out_C[i];
final_S_v7[i][idx-1] <= WT_v7_out_S[i];
final_C_v7[i][idx-1] <= WT_v7_out_C[i];
final_S_v8[i][idx-1] <= WT_v8_out_S[i];
final_C_v8[i][idx-1] <= WT_v8_out_C[i];
final_S_v9[i][idx-1] <= WT_v9_out_S[i];
final_C_v9[i][idx-1] <= WT_v9_out_C[i];
final_S_v10[i][idx-1] <= WT_v10_out_S[i];
final_C_v10[i][idx-1] <= WT_v10_out_C[i];
final_S_v11[i][idx-1] <= WT_v11_out_S[i];
final_C_v11[i][idx-1] <= WT_v11_out_C[i];
final_S_v12[i][idx-1] <= WT_v12_out_S[i];
final_C_v12[i][idx-1] <= WT_v12_out_C[i];
final_S_v13[i][idx-1] <= WT_v13_out_S[i];
final_C_v13[i][idx-1] <= WT_v13_out_C[i];
end
end
//idx <= idx + 1;
end
else if (state == 2) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_v0[i] <= WT_result_v0_wire[i];
WT_result_v1[i] <= WT_result_v1_wire[i];
WT_result_v2[i] <= WT_result_v2_wire[i];
WT_result_v3[i] <= WT_result_v3_wire[i];
WT_result_v4[i] <= WT_result_v4_wire[i];
WT_result_v5[i] <= WT_result_v5_wire[i];
WT_result_v6[i] <= WT_result_v6_wire[i];
WT_result_v7[i] <= WT_result_v7_wire[i];
WT_result_v8[i] <= WT_result_v8_wire[i];
WT_result_v9[i] <= WT_result_v9_wire[i];
WT_result_v10[i] <= WT_result_v10_wire[i];
WT_result_v11[i] <= WT_result_v11_wire[i];
WT_result_v12[i] <= WT_result_v12_wire[i];
WT_result_v13[i] <= WT_result_v13_wire[i];
end
//idx <= 0;
//state <= 3;
tree_valid <= 0;
end
else if (state == 3) begin
// MAC adder
//state <= 4;
end
else if (state == 4) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= MAC_out[i];
end
//state <= 5;
result_valid <= 1;
end
//if (state == 5) begin
// ouput, reserved cycle
//idx <= 0;
//state <= 0;
//tree_valid <= 0;
//end
end
end
endmodule
"""
return str
# %%
def process_task(i, name, weights_file_name, matrix, H, L, VN, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_gp_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_gp_{i}"
module = generate_module(
i,
module_name=module_name,
H=H,
L=L,
VN=VN,
value_range=config.value_range,
weights_file_name=weights_file_name,
config=config,
)
with open(file_name, "w") as f:
f.write(module)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
GP = int(VN / config.group_number)
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, VN, config
)
for i in range(GP)
]
for future in tqdm(as_completed(futures), total=GP):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
import os
import json
import pickle
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool
from hllm.config import CFG
from hllm.optimized.turbo_optimize_hn import generate_color_graph, greedy_coloring
def process_weight(args):
index, weight, L, H, value_range, weights_file_name, path_dir = args
node, graph = generate_color_graph(L, H, value_range, weight)
os.makedirs(path_dir, exist_ok=True)
node_file = os.path.join(path_dir, f"info_tp_{weights_file_name}_vc_{index}.json")
with open(node_file, "w") as f:
json.dump({"node": node}, f)
max_color = []
for i in range(len(value_range)):
colors = greedy_coloring(graph[i], node[i])
color_file = os.path.join(
path_dir, f"info_tp_{weights_file_name}_vc_{index}_value_{i}.json"
)
max_color.append(max(colors) + 1)
with open(color_file, "w") as f:
json.dump({"color": colors}, f)
max_color_file = os.path.join(
path_dir, f"info_tp_{weights_file_name}_vc_{index}_ww.json"
)
with open(max_color_file, "w") as f:
json.dump({"ww": max_color}, f)
def run(name: str, config: CFG):
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file_name}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
# print(VN, L, H)
path_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Generating color graph")
args = [
(i, weight, L, H, config.value_range, weights_file_name, path_dir)
for i, weight in enumerate(matrixs)
]
with Pool(config.num_workers) as pool:
list(tqdm(pool.imap(process_weight, args), total=VN))
print("Generating color graph at", path_dir)
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
GenerateBlock,
ForBlock,
add_parameter,
add_input,
add_output,
add_genvar,
add_assign,
add_wire,
add_body,
add_instance,
add_newline,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.utils import calculate_WW
# %%
def generate_module(
module_name,
L,
value_range,
WW,
):
L_width = int(np.ceil(np.log2(L)))
with ModuleBlock(module_name) as module:
# 参数
add_parameter("L", L)
for i in range(len(value_range)):
add_parameter(f"WW_{i}", WW[i])
# 输入输出
add_input("LM_sel", L_width)
for i in range(len(value_range)):
add_input(f"LM_in_{i}", f"WW_{i}", "L")
add_output(f"LM_out_{i}", f"WW_{i}")
# 内部连线
for i in range(len(value_range)):
add_wire(name=f"LM_in_{i}_masked", width=f"WW_{i}", height="L")
add_wire(name=f"LM_in_{i}_masked_T", width="L", height=f"WW_{i}")
add_newline()
# LM_select_loop
add_genvar("i")
with GenerateBlock():
with ForBlock("i=0", "i<L", "i=i+1", "LM_select_loop"):
for j in range(len(value_range)):
add_body(
f"assign LM_in_{j}_masked[i]=LM_in_{j}[i] & {{WW_{j}{{LM_sel[i]}}}};",
)
add_newline()
# LM_transpose_loop_out
add_genvar("j")
add_genvar("k")
with GenerateBlock():
with ForBlock("k=0", "k<L", "k=k+1", "LM_transpose_loop_out"):
for i in range(len(value_range)):
with ForBlock(
"j=0", f"j<WW_{i}", "j=j+1", f"LM_transpose_loop_in_{i}"
):
add_assign(
f"LM_in_{i}_masked_T",
["j", "k"],
f"LM_in_{i}_masked",
["k", "j"],
)
add_newline()
# LM_reduce_or_loop
add_genvar("m")
with GenerateBlock():
for i in range(len(value_range)):
with ForBlock("m=0", f"m<WW_{i}", "m=m+1", f"LM_reduce_or_loop_{i}"):
add_body(f"assign LM_out_{i}[m] = |(LM_in_{i}_masked_T[m]);")
return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
try:
WW = calculate_WW(matrix, config.value_range)
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_vc_{i}"
with open(file_name, "w") as f:
f.write(
generate_module(
module_name=module_name,
L=L,
value_range=config.value_range,
WW=WW,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import json
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.log import TCL
from hllm.utils import calculate_WW
# %%
def generate_module(
cur_GP=0,
module_name="",
H=16,
L=5,
VN=512,
value_range=[-1, 1],
weights_file_name=None,
config: CFG = None,
ww_list=None,
):
tcl = TCL(config, weights_file_name)
tcl.set_vlist(f"VLIST_tp_{weights_file_name}_gp_{cur_GP}")
L_width = int(np.ceil(np.log2(L)))
with ModuleBlock(module_name) as module:
GN = config.group_number
GP = int(VN / GN)
# 参数
add_parameter("H", H)
add_parameter("L", L)
add_parameter("VN", GN)
# 输入输出
add_input("clk")
add_input("tree_rstn")
add_input("valid")
add_input("CST_LOW")
add_input("LM_sel", L_width)
add_input("SW_in", "H")
for i in range(len(value_range)):
add_output(name=f"WT_{i}_out_S", height="VN")
add_output(name=f"WT_{i}_out_C", height="VN")
# 内部连线
for i in range(GN):
sw_ports = {
"clk": "clk",
"tree_rstn": "tree_rstn",
"valid": "valid",
"CST_LOW": "CST_LOW",
"SW_in": "SW_in",
"LM_sel": "LM_sel",
}
for j in range(len(value_range)):
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i}]"
add_instance(
f"Sub_wrapper_tp_{weights_file_name}_vc_{cur_GP*GN+i}",
f"Sub_wrapper_{cur_GP*GN+i}",
None,
sw_ports,
)
tcl.add_dependency(
f"Sub_wrapper",
f"Sub_wrapper_tp_{weights_file_name}_vc_{cur_GP*GN+i}.sv",
)
tcl.add_dependency(
f"WT_group",
f"WT_group_tp_{weights_file_name}_vc_{cur_GP*GN+i}.sv",
)
tcl.add_dependency(
f"Mid_wrapper",
f"Mid_wrapper_tp_{weights_file_name}_gp_{cur_GP}.sv",
)
tcl.add_dependency(
f"Layer_mux",
f"Layer_mux_tp_{weights_file_name}_vc_{cur_GP*GN+i}.sv",
)
tcl.add_dependency(
f"FSM",
f"FSM_tp_{weights_file_name}_gp_{cur_GP}.sv",
)
for j in range(len(value_range)):
tcl.add_dependency(
f"Mux_wrapper",
f"Mux_wrapper_tp_{weights_file_name}_vc_{cur_GP*GN+i}_value_{j}.sv",
)
tcl.add_dependency(
f"Mux",
f"Mux_tp_{weights_file_name}_vc_{cur_GP*GN+i}_value_{j}.sv",
)
for line in ww_list[cur_GP * GN + i]:
tcl.add_dependency(
f"SerialWallaceTree",
f"SerialWallaceTree{line}Input.v",
use_weights=False,
)
return module, tcl
# %%
def process_task(i, name, weights_file_name, ww_list, H, L, VN, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_gp_{i}.sv")
file_name_tcl = os.path.join(
file_dir, f"{name}_tp_{weights_file_name}_gp_{i}.tcl"
)
module_name = f"{name}_tp_{weights_file_name}_gp_{i}"
module, tcl = generate_module(
i,
module_name=module_name,
H=H,
L=L,
VN=VN,
value_range=config.value_range,
weights_file_name=weights_file_name,
config=config,
ww_list=ww_list,
)
with open(file_name, "w") as f:
f.write(module.generate())
with open(file_name_tcl, "w") as f:
f.write(tcl.generate())
return i # 返回任务ID以显示进度
except Exception as e:
print(
f"Generating {i} failed with an error at line {sys.exc_info()[2].tb_lineno}: {e}"
)
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
GP = int(VN / config.group_number)
ww_list = []
ww_files = [
os.path.join(
config.output_dir,
"info",
weights_file_name,
f"info_tp_{weights_file_name}_vc_{i}_ww.json",
)
for i in range(VN)
]
for ww_file in ww_files:
with open(ww_file, "r") as f:
ww = json.load(f)
ww_list.append(ww["ww"])
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, ww_list, H, L, VN, config
)
for i in range(GP)
]
for future in tqdm(as_completed(futures), total=GP):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import sys
import numpy as np
from pyrilog import (
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_body,
add_newline,
add_reg,
add_wire,
AlwaysBlock,
IfBlock,
ElseBlock,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
import json
# %%
def generate_module(
matrix,
module_name="mux",
H=16,
L=5,
value_range=[-1, 1],
WW=[8, 8],
CUR_VN=0,
CUR_VALUE_INDEX=0,
CUR_CNT=0,
node=None,
color=None,
weights_file_name="",
):
with ModuleBlock(f"{module_name}") as module:
add_input("in", H)
L_width = int(np.ceil(np.log2(L)))
add_input("sel", L_width)
add_output("out")
layer_to_in_map = {}
for i, hn_in_layers in enumerate(node):
if color[i] == -1 or color[i] != CUR_CNT:
continue
for j, hn_in_layer in enumerate(hn_in_layers):
layer_to_in_map[hn_in_layer] = i
add_reg("par_out", L)
with AlwaysBlock("*"):
for i in range(L):
with IfBlock(f"sel == {L_width}'b{i:0{L_width}b}"):
if i in layer_to_in_map:
add_body(f"par_out[{i}]=in[{layer_to_in_map[i]}];")
else:
add_body(f"par_out[{i}]=0;")
with ElseBlock():
add_body(f"par_out[{i}]=0;")
add_assign("out", [], " | ".join([f"par_out[{i}]" for i in range(L)]), [])
return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
node_file = os.path.join(
config.output_dir,
"info",
weights_file_name,
f"info_tp_{weights_file_name}_vc_{i}.json",
)
node = json.load(open(node_file))["node"]
for j in range(len(config.value_range)):
color_file = os.path.join(
config.output_dir,
"info",
weights_file_name,
f"info_tp_{weights_file_name}_vc_{i}_value_{j}.json",
)
color = json.load(open(color_file))["color"]
max_mux_port = max(color) + 1
text = ""
for k in range(max_mux_port):
text += generate_module(
matrix,
module_name=f"{name}_tp_{weights_file_name}_vc_{i}_value_{j}_color_{k}",
H=H,
L=L,
value_range=config.value_range,
node=node[j],
color=color,
CUR_VN=i,
CUR_VALUE_INDEX=j,
CUR_CNT=k,
weights_file_name=weights_file_name,
).generate()
text += "\n"
file_name = os.path.join(
file_dir,
f"{name}_tp_{weights_file_name}_vc_{i}_value_{j}.sv",
)
with open(file_name, "w") as f:
f.write(text)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file_name}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
# %%
# %%
import sys
import numpy as np
from pyrilog import (
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_body,
add_newline,
add_instance,
add_reg,
add_wire,
AlwaysBlock,
IfBlock,
ElseBlock,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
import json
# %%
def generate_module(
module_name="Mux_wrapper",
H=16,
L=5,
value_range=[-1, 1],
CUR_VN=0,
CUR_VALUE_INDEX=0,
max_mux_port=0,
node=None,
color=None,
weights_file_name="",
config: CFG = None,
name="",
):
with ModuleBlock(f"{module_name}") as module:
add_input("in", H)
L_width = int(np.ceil(np.log2(L)))
add_input("sel", L_width)
add_output("out", max_mux_port)
for i in range(max_mux_port):
add_instance(
module_name=f"Mux_tp_{weights_file_name}_vc_{CUR_VN}_value_{CUR_VALUE_INDEX}_color_{i}",
instance_name=f"Mux_{i}",
parameters={},
ports={"in": "in", "sel": "sel", "out": f"out[{i}]"},
)
return module
# %%
def process_task(i, name, weights_file_name, H, L, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
node_file = os.path.join(
config.output_dir,
"info",
weights_file_name,
f"info_tp_{weights_file_name}_vc_{i}.json",
)
node = json.load(open(node_file))["node"]
for j in range(len(config.value_range)):
color_file = os.path.join(
config.output_dir,
"info",
weights_file_name,
f"info_tp_{weights_file_name}_vc_{i}_value_{j}.json",
)
color = json.load(open(color_file))["color"]
max_mux_port = max(color) + 1
text = generate_module(
module_name=f"{name}_tp_{weights_file_name}_vc_{i}_value_{j}",
H=H,
L=L,
value_range=config.value_range,
node=node[j],
color=color,
CUR_VN=i,
CUR_VALUE_INDEX=j,
max_mux_port=max_mux_port,
weights_file_name=weights_file_name,
config=config,
name=name,
).generate()
file_name = os.path.join(
file_dir,
f"{name}_tp_{weights_file_name}_vc_{i}_value_{j}.sv",
)
with open(file_name, "w") as f:
f.write(text)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(process_task, i, name, weights_file_name, H, L, config)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
import json
# %%
def generate_module(
matrix,
module_name="",
weights_file_name="",
H=16,
L=5,
VN_index=1,
value_range=[-1, 1],
WW=[8, 8],
name="",
config: CFG = None,
):
L_width = int(np.ceil(np.log2(L)))
module_name_suffix = f"_tp_{weights_file_name}_vc_{VN_index}"
with ModuleBlock(f"{module_name}") as module:
# 参数
add_parameter("H", H)
add_parameter("L", L)
for i in range(len(value_range)):
add_parameter(f"WW_{i}", WW[i])
# 输入输出
add_input("clk")
add_input("tree_rstn")
add_input("valid")
add_input("CST_LOW")
add_input("LM_sel", L_width)
add_input("SW_in", "H")
for i in range(len(value_range)):
add_output(
name=f"WT_{i}_out_S",
)
add_output(
name=f"WT_{i}_out_C",
)
for i in range(len(value_range)):
add_wire(
name=f"LM_out_{i}",
width=f"WW_{i}",
)
# 实例化mux_wrapper
for j in range(len(value_range)):
mw_params = {}
mw_ports = {"in": "SW_in", "sel": "LM_sel", "out": f"LM_out_{j}"}
add_instance(
f"Mux_wrapper_tp_{weights_file_name}_vc_{VN_index}_value_{j}",
f"Mux_wrapper_{j}",
mw_params,
mw_ports,
)
# 实例化WT
wt_params = {}
for i in range(len(value_range)):
wt_params[f"WW_{i}"] = f"WW_{i}"
wt_ports = {"clk": "clk", "tree_rstn": "tree_rstn", "valid": "valid"}
for i in range(len(value_range)):
wt_ports[f"WT_{i}_in"] = f"LM_out_{i}"
wt_ports[f"WT_{i}_out_S"] = f"WT_{i}_out_S"
wt_ports[f"WT_{i}_out_C"] = f"WT_{i}_out_C"
add_instance("WT_group" + module_name_suffix, "WT_group", wt_params, wt_ports)
return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
try:
WW = [0] * len(config.value_range)
for j in range(len(config.value_range)):
color_file = os.path.join(
config.output_dir,
"info",
weights_file_name,
f"info_tp_{weights_file_name}_vc_{i}_value_{j}.json",
)
color = json.load(open(color_file))["color"]
WW[j] = max(color) + 1
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
with open(file_name, "w") as f:
module_name = f"{name}_tp_{weights_file_name}_vc_{i}"
f.write(
generate_module(
matrix,
module_name=module_name,
H=H,
weights_file_name=weights_file_name,
L=L,
VN_index=i,
value_range=config.value_range,
WW=WW,
name=name,
config=config,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name, config: CFG):
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file_name}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.utils import calculate_WW
# %%
def generate_module(
module_name,
H=16,
L=5,
VN=512,
value_range=[-1, 1],
weights_file_name=None,
config: CFG = None,
):
L_width = int(np.ceil(np.log2(L)))
with ModuleBlock(module_name) as module:
GN = config.group_number
GP = int(VN / GN)
# 参数
add_parameter("H", H)
add_parameter("L", L)
add_parameter("VN", VN)
# 输入输出
add_input("clk")
add_input("tree_rstn")
add_input("valid")
add_input("CST_LOW")
add_input("LM_sel", L_width)
add_input("SW_in", "H")
for i in range(len(value_range)):
add_output(name=f"WT_{i}_out_S", height="VN")
add_output(name=f"WT_{i}_out_C", height="VN")
# 内部连线
for i in range(GP):
sw_ports = {
"clk": "clk",
"tree_rstn": "tree_rstn",
"valid": "valid",
"CST_LOW": "CST_LOW",
"SW_in": "SW_in",
"LM_sel": "LM_sel",
}
for j in range(len(value_range)):
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i*GN+GN-1}:{i*GN}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i*GN+GN-1}:{i*GN}]"
add_instance(
f"Mid_wrapper_tp_{weights_file_name}_gp_{i}",
f"Mid_wrapper_{i}",
None,
sw_ports,
)
return module
# %%
def process_task(i, name, weights_file_name, H, L, VN, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}.sv")
module_name = f"{name}_tp_{weights_file_name}"
with open(file_name, "w") as f:
f.write(
generate_module(
module_name,
H=H,
L=L,
VN=VN,
value_range=config.value_range,
weights_file_name=weights_file_name,
config=config,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task,
i,
name,
weights_file_name,
H,
L,
VN,
config,
)
for i in range(1)
]
for future in tqdm(as_completed(futures), total=1):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import sys
import numpy as np
from pyrilog import (
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_body,
add_newline,
add_instance,
add_reg,
add_wire,
AlwaysBlock,
IfBlock,
ElseBlock,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
import json
# %%
def generate_module(
module_name="mux_wrapper",
H=16,
L=5,
value_range=[-1, 1],
CUR_VN=0,
weights_file_name="",
config: CFG = None,
name="",
):
with ModuleBlock(f"{module_name}") as module:
ww = []
for i in range(len(value_range)):
color_file = os.path.join(
config.output_dir,
"info",
weights_file_name,
f"info_tp_{weights_file_name}_vc_{CUR_VN}_value_{i}.json",
)
color = json.load(open(color_file))["color"]
max_mux_port = max(color) + 1
ww.append(max_mux_port)
for i in range(len(value_range)):
add_parameter(f"WW_{i}", ww[i])
add_input("clk")
add_input("tree_rstn")
add_input("valid")
for i in range(len(value_range)):
add_input(f"WT_{i}_in", f"WW_{i}")
add_output(
f"WT_{i}_out_S",
)
add_output(
f"WT_{i}_out_C",
)
# 内部连线华莱士树
for i in range(len(value_range)):
wallace_name = f"SerialWallaceTree{ww[i]}Input"
wallace_port = {
"clk": "clk",
"rstn": "tree_rstn",
"valid": "valid",
"addends": f"WT_{i}_in",
"out_S": f"WT_{i}_out_S",
"out_Cout": f"WT_{i}_out_C",
}
add_instance(wallace_name, f"serial_wallace_tree_{i}", {}, wallace_port)
return module
# %%
def process_task(i, name, weights_file_name, H, L, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(
file_dir,
f"{name}_tp_{weights_file_name}_vc_{i}.sv",
)
with open(file_name, "w") as f:
text = generate_module(
module_name=f"{name}_tp_{weights_file_name}_vc_{i}",
H=H,
L=L,
value_range=config.value_range,
CUR_VN=i,
weights_file_name=weights_file_name,
config=config,
name=name,
).generate()
f.write(text)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file_name}")
with open(weights_file, "rb") as f:
print(f"Loading {weights_file_name}")
matrixs = pickle.load(f)
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(process_task, i, name, weights_file_name, H, L, config)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
# %%
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <set>
#include <stdexcept>
#include <vector>
using namespace std;
namespace py = pybind11;
vector<int> greedyGraphColoring(const vector<vector<int>> &adjMatrix, const vector<vector<int>> &layer)
{
int n = adjMatrix.size();
vector<int> colors(n, -1); // 存储染色结果,初始-1表示未染色
// 检查邻接矩阵有效性
for (const auto &row : adjMatrix)
{
if (row.size() != n)
{
throw invalid_argument("邻接矩阵必须是方阵");
}
}
// 遍历所有节点
for (int node = 0; node < n; ++node)
{
// 如果该节点对应的layer为空,保持颜色为-1
if (layer[node].empty()) {
continue;
}
set<int> usedColors;
// 检查所有相邻节点的已用颜色
for (int neighbor = 0; neighbor < n; ++neighbor)
{
if (adjMatrix[node][neighbor] && colors[neighbor] != -1)
{
usedColors.insert(colors[neighbor]);
}
}
// 寻找最小可用颜色
int color = 0;
while (true)
{
if (usedColors.find(color) == usedColors.end())
{
colors[node] = color;
break;
}
color++;
}
}
return colors;
}
bool hasIntersection(const vector<int> &node1, const vector<int> &node2)
{
bool bucket[100] = {false}; // 初始化桶数组
// 将node1的元素放入桶中
for (int elem : node1)
{
bucket[elem] = true;
}
// 检查node2的元素是否在桶中存在
for (int elem : node2)
{
if (bucket[elem])
{
return true;
}
}
return false;
}
tuple<vector<vector<vector<int>>>, vector<vector<vector<int>>>>
generateColorGraph(int L, int W, const vector<double> &value_range,
py::array_t<int> &matrix)
{
auto buf = matrix.request();
int *ptr = static_cast<int *>(buf.ptr);
// 转置矩阵
vector<vector<int>> transposed(W, vector<int>(L));
for (int i = 0; i < L; i++)
{
for (int j = 0; j < W; j++)
{
transposed[j][i] = ptr[i * W + j];
}
}
// 初始化node和graph
vector<vector<vector<int>>> node(value_range.size(), vector<vector<int>>(W));
vector<vector<vector<int>>> graph(value_range.size(),
vector<vector<int>>(W, vector<int>(W, 0)));
// 构建node
for (int i = 0; i < W; i++)
{
for (int j = 0; j < L; j++)
{
int val = transposed[i][j];
if (val != -1)
{
node[val][i].push_back(j);
}
}
}
// 构建graph
for (size_t i = 0; i < value_range.size(); i++)
{
for (int j = 0; j < W; j++)
{
for (int k = 0; k < W; k++)
{
if (!node[i][j].empty() && !node[i][k].empty())
{
graph[i][j][k] = hasIntersection(node[i][j], node[i][k]) ? 1 : 0;
}
}
}
}
return make_tuple(node, graph);
}
PYBIND11_MODULE(turbo_optimize_hn, m)
{
m.doc() = "图着色贪心算法模块";
m.def("greedy_coloring", &greedyGraphColoring, "基于贪心算法的图着色实现",
py::arg("adj_matrix"), py::arg("layer"));
m.def("generate_color_graph", &generateColorGraph, "生成颜色图", py::arg("L"),
py::arg("W"), py::arg("value_range"), py::arg("matrix"));
}
\ No newline at end of file
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from tqdm import tqdm
import pickle
from hllm.config import CFG
# %%
def generate_module(
cur_GP=0,
module_name="",
H=16,
L=5,
VN=512,
value_range=[-1, 1],
weights_file_name=None,
WW=[8, 8],
config: CFG = None,
):
module_name_suffix = f"_tp_{weights_file_name}_gp_{cur_GP}"
str = f"""
module FSM_tp_{weights_file_name}_gp_{cur_GP} #(
parameter H = 896, // hidden layer dim, 896
parameter L = 24, // layer num, 24
parameter VN = 32, // vector num, 896 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 18, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 32 // MAC output total width, SCW + WP + 4
) (
input clk,
input valid,
input fsm_rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
output reg [TTW - 1 : 0] WT_result_acc[VN - 1 : 0],
output reg result_valid
);"""
str += r"""
reg [AP - 1 : 0] TM_sel;
wire [ H - 1 : 0] TM_out;
Top_mux #(
.H (H),
.AP(AP)
) top_mux (
.TM_sel(TM_sel),
.TM_in (Top_in),
.TM_out(TM_out)
);
wire WT_v0_out_S[VN - 1 : 0];
wire WT_v0_out_C[VN - 1 : 0];
wire WT_v1_out_S[VN - 1 : 0];
wire WT_v1_out_C[VN - 1 : 0];
wire WT_v2_out_S[VN - 1 : 0];
wire WT_v2_out_C[VN - 1 : 0];
wire WT_v3_out_S[VN - 1 : 0];
wire WT_v3_out_C[VN - 1 : 0];
wire WT_v4_out_S[VN - 1 : 0];
wire WT_v4_out_C[VN - 1 : 0];
wire WT_v5_out_S[VN - 1 : 0];
wire WT_v5_out_C[VN - 1 : 0];
wire WT_v6_out_S[VN - 1 : 0];
wire WT_v6_out_C[VN - 1 : 0];
wire WT_v7_out_S[VN - 1 : 0];
wire WT_v7_out_C[VN - 1 : 0];
wire WT_v8_out_S[VN - 1 : 0];
wire WT_v8_out_C[VN - 1 : 0];
wire WT_v9_out_S[VN - 1 : 0];
wire WT_v9_out_C[VN - 1 : 0];
wire WT_v10_out_S[VN - 1 : 0];
wire WT_v10_out_C[VN - 1 : 0];
wire WT_v11_out_S[VN - 1 : 0];
wire WT_v11_out_C[VN - 1 : 0];
wire WT_v12_out_S[VN - 1 : 0];
wire WT_v12_out_C[VN - 1 : 0];
wire WT_v13_out_S[VN - 1 : 0];
wire WT_v13_out_C[VN - 1 : 0];
reg tree_valid; // carry reg clear if not valid
reg [SCW - 1 : 0] final_S_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v13[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v13[VN - 1 : 0];
wire [TTW - 1 : 0] MAC_out[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v0[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v1[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v2[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v3[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v4[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v5[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v6[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v7[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v8[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v9[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v10[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v11[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v12[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v13[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v0_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v1_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v2_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v3_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v4_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v5_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v6_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v7_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v8_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v9_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v10_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v11_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v12_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v13_wire[VN - 1 : 0];
reg [7 : 0] state_idx; // MSB 3: state; LSB 5: idx
reg [7 : 0] next_state_idx;
wire [2 : 0] state;
wire [4 : 0] idx;
assign state = state_idx[7 : 5];
assign idx = state_idx[4 : 0];
reg CST_LOW;
genvar j;
integer i;"""
str += f"""
Mid_wrapper_tp_{weights_file_name}_gp_{cur_GP} #(
.H (H),
.L (L),
.VN(VN)
) mid_wrappers (
.clk(clk),
.tree_rstn(fsm_rstn),
.valid(tree_valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(TM_out),
.WT_0_out_S(WT_v0_out_S),
.WT_0_out_C(WT_v0_out_C),
.WT_1_out_S(WT_v1_out_S),
.WT_1_out_C(WT_v1_out_C),
.WT_2_out_S(WT_v2_out_S),
.WT_2_out_C(WT_v2_out_C),
.WT_3_out_S(WT_v3_out_S),
.WT_3_out_C(WT_v3_out_C),
.WT_4_out_S(WT_v4_out_S),
.WT_4_out_C(WT_v4_out_C),
.WT_5_out_S(WT_v5_out_S),
.WT_5_out_C(WT_v5_out_C),
.WT_6_out_S(WT_v6_out_S),
.WT_6_out_C(WT_v6_out_C),
.WT_7_out_S(WT_v7_out_S),
.WT_7_out_C(WT_v7_out_C),
.WT_8_out_S(WT_v8_out_S),
.WT_8_out_C(WT_v8_out_C),
.WT_9_out_S(WT_v9_out_S),
.WT_9_out_C(WT_v9_out_C),
.WT_10_out_S(WT_v10_out_S),
.WT_10_out_C(WT_v10_out_C),
.WT_11_out_S(WT_v11_out_S),
.WT_11_out_C(WT_v11_out_C),
.WT_12_out_S(WT_v12_out_S),
.WT_12_out_C(WT_v12_out_C),
.WT_13_out_S(WT_v13_out_S),
.WT_13_out_C(WT_v13_out_C)
);
"""
str += r"""
generate
for (j = 0; j < VN; j = j + 1) begin : inst_SW_loop
MAC #(
.W_1(SCW), // input 1 width
.W_2(WP), // input 2 width
.W_O(TTW), // output width
.NUM(16) // parallel width
) mac (
.clk(clk),
.rstn(fsm_rstn),
.MAC_in_1({
{SCW{1'b0}},
{SCW{1'b0}},
WT_result_v13[j],
WT_result_v12[j],
WT_result_v11[j],
WT_result_v10[j],
WT_result_v9[j],
WT_result_v8[j],
WT_result_v7[j],
WT_result_v6[j],
WT_result_v5[j],
WT_result_v4[j],
WT_result_v3[j],
WT_result_v2[j],
WT_result_v1[j],
WT_result_v0[j]
}),
//.MAC_in_2({weight_0, weight_1, -5'd6, -5'd4, -5'd3, -5'd2, -5'd1, 5'd0, 5'd1, 5'd2, 5'd3, 5'd4, 5'd6, 5'd8, 5'd12, 5'd0}),
.MAC_out(MAC_out[j])
);
end
endgenerate
genvar k;
generate
for (k = 0; k < VN; k = k + 1) begin
assign WT_result_v0_wire[k] = {1'b0, final_S_v0[k]} + {final_C_v0[k], 1'b0};
assign WT_result_v1_wire[k] = {1'b0, final_S_v1[k]} + {final_C_v1[k], 1'b0};
assign WT_result_v2_wire[k] = {1'b0, final_S_v2[k]} + {final_C_v2[k], 1'b0};
assign WT_result_v3_wire[k] = {1'b0, final_S_v3[k]} + {final_C_v3[k], 1'b0};
assign WT_result_v4_wire[k] = {1'b0, final_S_v4[k]} + {final_C_v4[k], 1'b0};
assign WT_result_v5_wire[k] = {1'b0, final_S_v5[k]} + {final_C_v5[k], 1'b0};
assign WT_result_v6_wire[k] = {1'b0, final_S_v6[k]} + {final_C_v6[k], 1'b0};
assign WT_result_v7_wire[k] = {1'b0, final_S_v7[k]} + {final_C_v7[k], 1'b0};
assign WT_result_v8_wire[k] = {1'b0, final_S_v8[k]} + {final_C_v8[k], 1'b0};
assign WT_result_v9_wire[k] = {1'b0, final_S_v9[k]} + {final_C_v9[k], 1'b0};
assign WT_result_v10_wire[k] = {1'b0, final_S_v10[k]} + {final_C_v10[k], 1'b0};
assign WT_result_v11_wire[k] = {1'b0, final_S_v11[k]} + {final_C_v11[k], 1'b0};
assign WT_result_v12_wire[k] = {1'b0, final_S_v12[k]} + {final_C_v12[k], 1'b0};
assign WT_result_v13_wire[k] = {1'b0, final_S_v13[k]} + {final_C_v13[k], 1'b0};
end
endgenerate
// fsm next state generation
always @(state_idx, valid) begin
case(state)
3'b000: begin
if (valid == 1) next_state_idx = 8'b00100000;
else next_state_idx = 0;
end
3'b001: begin
if (idx == SCW) next_state_idx = 8'b01000000;
else next_state_idx = state_idx + 1;
end
3'b010: next_state_idx = 8'b01100000;
3'b011: next_state_idx = 8'b10000000;
3'b100: next_state_idx = 0;
default: next_state_idx = 0;
endcase
end
// fsm state transfer
always @(posedge clk) begin
if(!fsm_rstn)
state_idx <= 0;
else
state_idx <= next_state_idx;
end
//fsm output
always @(posedge clk or negedge fsm_rstn) begin
if (!fsm_rstn) begin
//state <= 0;
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
CST_LOW <= 0;
TM_sel <= 8'b00000000;
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= 0;
end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
end else begin
if (state == 0) begin
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
/*
if (valid == 1) begin
state <= 1;
end
else begin
state <= 0;
end
*/
end
else if (state == 1) begin
tree_valid <= 1;
if (idx == 0) begin
TM_sel <= 8'b00000001;
end else begin
case (idx)
1: TM_sel <= 8'b00000010;
2: TM_sel <= 8'b00000100;
3: TM_sel <= 8'b00001000;
4: TM_sel <= 8'b00010000;
5: TM_sel <= 8'b00100000;
6: TM_sel <= 8'b01000000;
7: TM_sel <= 8'b10000000;
default: TM_sel <= 8'b10000000; // signed extension
endcase
//if (idx == SCW) begin
// state <= 2;
//end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i][idx-1] <= WT_v0_out_S[i];
final_C_v0[i][idx-1] <= WT_v0_out_C[i];
final_S_v1[i][idx-1] <= WT_v1_out_S[i];
final_C_v1[i][idx-1] <= WT_v1_out_C[i];
final_S_v2[i][idx-1] <= WT_v2_out_S[i];
final_C_v2[i][idx-1] <= WT_v2_out_C[i];
final_S_v3[i][idx-1] <= WT_v3_out_S[i];
final_C_v3[i][idx-1] <= WT_v3_out_C[i];
final_S_v4[i][idx-1] <= WT_v4_out_S[i];
final_C_v4[i][idx-1] <= WT_v4_out_C[i];
final_S_v5[i][idx-1] <= WT_v5_out_S[i];
final_C_v5[i][idx-1] <= WT_v5_out_C[i];
final_S_v6[i][idx-1] <= WT_v6_out_S[i];
final_C_v6[i][idx-1] <= WT_v6_out_C[i];
final_S_v7[i][idx-1] <= WT_v7_out_S[i];
final_C_v7[i][idx-1] <= WT_v7_out_C[i];
final_S_v8[i][idx-1] <= WT_v8_out_S[i];
final_C_v8[i][idx-1] <= WT_v8_out_C[i];
final_S_v9[i][idx-1] <= WT_v9_out_S[i];
final_C_v9[i][idx-1] <= WT_v9_out_C[i];
final_S_v10[i][idx-1] <= WT_v10_out_S[i];
final_C_v10[i][idx-1] <= WT_v10_out_C[i];
final_S_v11[i][idx-1] <= WT_v11_out_S[i];
final_C_v11[i][idx-1] <= WT_v11_out_C[i];
final_S_v12[i][idx-1] <= WT_v12_out_S[i];
final_C_v12[i][idx-1] <= WT_v12_out_C[i];
final_S_v13[i][idx-1] <= WT_v13_out_S[i];
final_C_v13[i][idx-1] <= WT_v13_out_C[i];
end
end
//idx <= idx + 1;
end
else if (state == 2) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_v0[i] <= WT_result_v0_wire[i];
WT_result_v1[i] <= WT_result_v1_wire[i];
WT_result_v2[i] <= WT_result_v2_wire[i];
WT_result_v3[i] <= WT_result_v3_wire[i];
WT_result_v4[i] <= WT_result_v4_wire[i];
WT_result_v5[i] <= WT_result_v5_wire[i];
WT_result_v6[i] <= WT_result_v6_wire[i];
WT_result_v7[i] <= WT_result_v7_wire[i];
WT_result_v8[i] <= WT_result_v8_wire[i];
WT_result_v9[i] <= WT_result_v9_wire[i];
WT_result_v10[i] <= WT_result_v10_wire[i];
WT_result_v11[i] <= WT_result_v11_wire[i];
WT_result_v12[i] <= WT_result_v12_wire[i];
WT_result_v13[i] <= WT_result_v13_wire[i];
end
//idx <= 0;
//state <= 3;
tree_valid <= 0;
end
else if (state == 3) begin
// MAC adder
//state <= 4;
end
else if (state == 4) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= MAC_out[i];
end
//state <= 5;
result_valid <= 1;
end
//if (state == 5) begin
// ouput, reserved cycle
//idx <= 0;
//state <= 0;
//tree_valid <= 0;
//end
end
end
endmodule
"""
return str
# %%
def process_task(i, name, weights_file_name, matrix, H, L, VN, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_gp_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_gp_{i}"
module = generate_module(
i,
module_name=module_name,
H=H,
L=L,
VN=VN,
value_range=config.value_range,
weights_file_name=weights_file_name,
config=config,
)
with open(file_name, "w") as f:
f.write(module)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
GP = int(VN / config.group_number)
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, VN, config
)
for i in range(GP)
]
for future in tqdm(as_completed(futures), total=GP):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import sys
import numpy as np
from pyrilog import ModuleBlock, add_parameter, add_input, add_output, add_assign
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.utils import calculate_WW, find_index
# %%
def generate_module(
matrix,
module_name="HN",
H=16,
L=5,
value_range=[-1, 1],
WW=[8, 8],
):
# with VerilogGenerator() as generator:
with ModuleBlock(f"{module_name}") as module:
add_parameter("H", H)
add_parameter("L", L)
for i in range(len(value_range)):
add_parameter(f"WW_{i}", WW[i])
add_input("HN_in", "H")
add_input("CST_LOW")
for i in range(len(value_range)):
add_output(
f"HN_out_{i}",
f"WW_{i}",
"L",
)
# 内部连线
for i, layer in enumerate(matrix):
weight_cnt = [0] * len(value_range)
for j, weight in enumerate(layer):
# 跳0
if abs(weight) < 1e-3:
continue
try:
index = find_index(value_range, weight)
except ValueError:
print(f"weight {weight} not found")
continue
add_assign(
f"HN_out_{index}",
[i, weight_cnt[index]],
"HN_in",
[j],
)
weight_cnt[index] += 1
for j in range(len(weight_cnt)):
while weight_cnt[j] < WW[j]:
add_assign(f"HN_out_{j}", [i, weight_cnt[j]], "CST_LOW", [])
weight_cnt[j] += 1
return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
try:
WW = calculate_WW(matrix, config.value_range)
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_vc_{i}"
with open(file_name, "w") as f:
text = generate_module(
matrix,
module_name=module_name,
H=H,
L=L,
value_range=config.value_range,
WW=WW,
).generate()
f.write(text)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
GenerateBlock,
ForBlock,
add_parameter,
add_input,
add_output,
add_genvar,
add_assign,
add_wire,
add_body,
add_instance,
add_newline,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.utils import calculate_WW
# %%
def generate_module(
module_name,
L,
value_range,
WW,
):
with ModuleBlock(module_name) as module:
# 参数
add_parameter("L", L)
for i in range(len(value_range)):
add_parameter(f"WW_{i}", WW[i])
# 输入输出
add_input("LM_sel", "L")
for i in range(len(value_range)):
add_input(f"LM_in_{i}", f"WW_{i}", "L")
add_output(f"LM_out_{i}", f"WW_{i}")
# 内部连线
for i in range(len(value_range)):
add_wire(name=f"LM_in_{i}_masked", width=f"WW_{i}", height="L")
add_wire(name=f"LM_in_{i}_masked_T", width="L", height=f"WW_{i}")
add_newline()
# LM_select_loop
add_genvar("i")
with GenerateBlock():
with ForBlock("i=0", "i<L", "i=i+1", "LM_select_loop"):
for j in range(len(value_range)):
add_body(
f"assign LM_in_{j}_masked[i]=LM_in_{j}[i] & {{WW_{j}{{LM_sel[i]}}}};",
)
add_newline()
# LM_transpose_loop_out
add_genvar("j")
add_genvar("k")
with GenerateBlock():
with ForBlock("k=0", "k<L", "k=k+1", "LM_transpose_loop_out"):
for i in range(len(value_range)):
with ForBlock(
"j=0", f"j<WW_{i}", "j=j+1", f"LM_transpose_loop_in_{i}"
):
add_assign(
f"LM_in_{i}_masked_T",
["j", "k"],
f"LM_in_{i}_masked",
["k", "j"],
)
add_newline()
# LM_reduce_or_loop
add_genvar("m")
with GenerateBlock():
for i in range(len(value_range)):
with ForBlock("m=0", f"m<WW_{i}", "m=m+1", f"LM_reduce_or_loop_{i}"):
add_body(f"assign LM_out_{i}[m] = |(LM_in_{i}_masked_T[m]);")
return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
try:
WW = calculate_WW(matrix, config.value_range)
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_vc_{i}"
with open(file_name, "w") as f:
f.write(
generate_module(
module_name=module_name,
L=L,
value_range=config.value_range,
WW=WW,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.log import TCL, TCL_dependency
from hllm.utils import calculate_WW
# %%
def generate_module(
cur_GP=0,
module_name="",
H=16,
L=5,
VN=512,
value_range=[-1, 1],
weights_file_name=None,
config: CFG = None,
ww_list=None,
):
tcl = TCL(config, weights_file_name)
tcl.set_vlist(f"VLIST_tp_{weights_file_name}_gp_{cur_GP}")
with ModuleBlock(module_name) as module:
GN = config.group_number
GP = int(VN / GN)
# 参数
add_parameter("H", H)
add_parameter("L", L)
add_parameter("VN", GN)
# 输入输出
add_input("clk")
add_input("tree_rstn")
add_input("valid")
add_input("CST_LOW")
add_input("LM_sel", "L")
add_input("SW_in", "H")
for i in range(len(value_range)):
add_output(name=f"WT_{i}_out_S", height="VN")
add_output(name=f"WT_{i}_out_C", height="VN")
# 内部连线
for i in range(GN):
sw_ports = {
"clk": "clk",
"tree_rstn": "tree_rstn",
"valid": "valid",
"CST_LOW": "CST_LOW",
"SW_in": "SW_in",
"LM_sel": "LM_sel",
}
for j in range(len(value_range)):
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i}]"
add_instance(
f"Sub_wrapper_tp_{weights_file_name}_vc_{cur_GP*GN+i}",
f"Sub_wrapper_{cur_GP*GN+i}",
None,
sw_ports,
)
tcl.add_dependency(
f"HN",
f"HN_tp_{weights_file_name}_vc_{cur_GP*GN+i}.sv",
)
tcl.add_dependency(
f"WT_group",
f"WT_group_tp_{weights_file_name}_vc_{cur_GP*GN+i}.sv",
)
tcl.add_dependency(
f"Layer_mux",
f"Layer_mux_tp_{weights_file_name}_vc_{cur_GP*GN+i}.sv",
)
tcl.add_dependency(
f"Sub_wrapper",
f"Sub_wrapper_tp_{weights_file_name}_vc_{cur_GP*GN+i}.sv",
)
tcl.add_dependency(
f"Mid_wrapper",
f"Mid_wrapper_tp_{weights_file_name}_gp_{cur_GP}.sv",
)
tcl.add_dependency(
f"FSM",
f"FSM_tp_{weights_file_name}_gp_{cur_GP}.sv",
)
for line in ww_list[cur_GP * GN + i]:
tcl.add_dependency(
f"SerialWallaceTree",
f"SerialWallaceTree{line}Input.v",
use_weights=False,
)
return module, tcl
# %%
def process_task(i, name, weights_file_name, ww_list, H, L, VN, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_gp_{i}.sv")
file_name_tcl = os.path.join(
file_dir, f"{name}_tp_{weights_file_name}_gp_{i}.tcl"
)
module_name = f"{name}_tp_{weights_file_name}_gp_{i}"
module, tcl = generate_module(
i,
module_name=module_name,
H=H,
L=L,
VN=VN,
value_range=config.value_range,
weights_file_name=weights_file_name,
config=config,
ww_list=ww_list,
)
with open(file_name, "w") as f:
f.write(module.generate())
with open(file_name_tcl, "w") as f:
f.write(tcl.generate())
return i # 返回任务ID以显示进度
except Exception as e:
print(
f"Generating {i} failed with an error at line {sys.exc_info()[2].tb_lineno}: {e}"
)
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
GP = int(VN / config.group_number)
ww_list = []
ww_files = [
os.path.join(
config.output_dir,
"WW",
weights_file_name,
f"WW_tp_{weights_file_name}_vc_{i}.txt",
)
for i in range(VN)
]
for ww_file in ww_files:
with open(ww_file, "r") as f:
ww = []
for line in f:
ww.append(int(line.strip()))
ww_list.append(ww)
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, ww_list, H, L, VN, config
)
for i in range(GP)
]
for future in tqdm(as_completed(futures), total=GP):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import numpy as np
from pyrilog import (
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.utils import calculate_WW
# %%
def generate_module(
matrix,
module_name,
H=16,
L=5,
value_range=[-1, 1],
WW=[8, 8],
weights_file_name=None,
config: CFG = None,
VN_index=1,
):
module_name_suffix = f"_tp_{weights_file_name}_vc_{VN_index}"
with ModuleBlock(module_name) as module:
# 参数
add_parameter("H", H)
add_parameter("L", L)
for i in range(len(value_range)):
add_parameter(f"WW_{i}", WW[i])
# 输入输出
add_input("clk")
add_input("tree_rstn")
add_input("valid")
add_input("CST_LOW")
add_input("LM_sel", "L")
add_input("SW_in", "H")
for i in range(len(value_range)):
add_output(
name=f"WT_{i}_out_S",
)
add_output(
name=f"WT_{i}_out_C",
)
# 内部连线
# add_wire("CST_LOW")
# add_assign("CST_LOW", [], 0, [])
for i in range(len(value_range)):
add_wire(name=f"HN_out_{i}", width=f"WW_{i}", height="L")
add_wire(
name=f"LM_out_{i}",
width=f"WW_{i}",
)
# 实例化HN
hn_params = {
"H": "H",
"L": "L",
}
for i in range(len(value_range)):
hn_params[f"WW_{i}"] = f"WW_{i}"
hn_ports = {
"HN_in": "SW_in",
"CST_LOW": "CST_LOW",
}
for i in range(len(value_range)):
hn_ports[f"HN_out_{i}"] = f"HN_out_{i}"
add_instance("HN" + module_name_suffix, "HN", hn_params, hn_ports)
# 实例化LM
lm_params = {
"L": L,
}
for i in range(len(value_range)):
lm_params[f"WW_{i}"] = f"WW_{i}"
lm_ports = {
"LM_sel": "LM_sel",
}
for i in range(len(value_range)):
lm_ports[f"LM_in_{i}"] = f"HN_out_{i}"
lm_ports[f"LM_out_{i}"] = f"LM_out_{i}"
add_instance("Layer_mux" + module_name_suffix, "Layer_mux", lm_params, lm_ports)
# 实例化WT
wt_params = {}
for i in range(len(value_range)):
wt_params[f"WW_{i}"] = f"WW_{i}"
wt_ports = {"clk": "clk", "tree_rstn": "tree_rstn", "valid": "valid"}
for i in range(len(value_range)):
wt_ports[f"WT_{i}_in"] = f"LM_out_{i}"
wt_ports[f"WT_{i}_out_S"] = f"WT_{i}_out_S"
wt_ports[f"WT_{i}_out_C"] = f"WT_{i}_out_C"
add_instance("WT_group" + module_name_suffix, "WT_group", wt_params, wt_ports)
return module
# %%
def process_task(i, name, weights_file_name, matrix, H, L, config):
try:
WW = calculate_WW(matrix, config.value_range)
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
module_name = f"{name}_tp_{weights_file_name}_vc_{i}"
with open(file_name, "w") as f:
f.write(
generate_module(
matrix,
module_name=module_name,
H=H,
L=L,
value_range=config.value_range,
WW=WW,
weights_file_name=weights_file_name,
config=config,
VN_index=i,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name)
print("Files generated in", file_dir)
import argparse
import sys
import math
import os
from tqdm import tqdm
from hllm.config import CFG
def gen_fulladder():
code = """module FullAdder(
input A, // First input bit
input B, // Second input bit
input Cin, // Carry input bit
output S, // Sum output bit
output Cout // Carry output bit
);
assign S = A ^ B ^ Cin;
assign Cout = (A & B) | (B & Cin) | (A & Cin);
endmodule
"""
return code
def gen_wallace_tree_config(num_addends):
full_adder_list = []
remainder_list = []
total_input_list = []
while num_addends > 2:
full_adders_used = num_addends // 3
remaining_addends = num_addends % 3
full_adder_list.append(full_adders_used)
remainder_list.append(remaining_addends)
total_input_list.append(num_addends)
num_addends = full_adders_used * 2 + remaining_addends
return full_adder_list, remainder_list, total_input_list
def gen_wallacetree(num_addends, full_adder_list, remainder_list, total_input_list):
code = ""
cout_cin_code = ""
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1: # the final bit we will manually manage
cout_cin_code += f" output [{full_adder_count} - 1 : 0] L{i}_Cout,\n"
cout_cin_code += f" input [{full_adder_count} - 1 : 0] L{i+1}_Cin,\n"
module_head_code = f"""module WallaceTree{num_addends}Input(
input [{num_addends} - 1 : 0] addends,
{cout_cin_code}
output final_Cout,
output final_S
);
"""
code += module_head_code
for i, (full_adder_count, remainder_count, total_input_count) in enumerate(
zip(full_adder_list, remainder_list, total_input_list)
):
code += f" wire [{total_input_count} - 1 : 0] L{i}_all_inputs;\n"
if i == 0:
code += f" assign L{i}_all_inputs = addends;\n"
else:
last_remainder_count = remainder_list[i - 1]
if last_remainder_count == 0:
concat_code = f"{{L{i-1}_S, L{i}_Cin}}"
else:
concat_code = f"{{L{i-1}_S, L{i}_Cin, L{i-1}_remainder}}"
code += f" assign L{i}_all_inputs = {concat_code};\n"
if remainder_count != 0:
code += f" wire [{remainder_count} - 1 : 0] L{i}_remainder;\n"
code += f" assign L{i}_remainder = L{i}_all_inputs[{total_input_count} - 1 : {total_input_count} - {remainder_count}];\n"
if i != len(full_adder_list) - 1: # otherwise directly assign to output pin
code += f" wire [{full_adder_count} - 1 : 0] L{i}_S;\n"
cout_code = f"L{i}_Cout" if i != len(full_adder_list) - 1 else "final_Cout"
S_code = f"L{i}_S" if i != len(full_adder_list) - 1 else "final_S"
code += f"""\
FullAdder L{i}_adders [{full_adder_count} - 1 : 0](
.A(L{i}_all_inputs[{full_adder_count} * 3 - 1 : {full_adder_count} * 2]),
.B(L{i}_all_inputs[{full_adder_count} * 2 - 1 : {full_adder_count}]),
.Cin(L{i}_all_inputs[{full_adder_count} - 1 : 0]),
.Cout({cout_code}),
.S({S_code})
);
"""
code += "endmodule\n\n"
return code
def gen_serialwallacetree(num_addends, full_adder_list):
code = ""
code += f"""module SerialWallaceTree{num_addends}Input(
input clk,
input rstn,
input valid,
input [{num_addends} - 1 : 0] addends,
output out_S,
output out_Cout
);
"""
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1:
code += f" wire [{full_adder_count} - 1 : 0] L{i}_Cout;\n"
code += f" wire [{full_adder_count} - 1 : 0] L{i+1}_Cin;\n"
code += f" reg [{full_adder_count} - 1 : 0] L{i}_Cout_L{i+1}_Cin_reg;\n"
code += f" assign L{i+1}_Cin = L{i}_Cout_L{i+1}_Cin_reg;\n\n"
cin_cout_assign_code = ""
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1:
cin_cout_assign_code += f" .L{i}_Cout(L{i}_Cout),\n"
cin_cout_assign_code += f" .L{i+1}_Cin(L{i+1}_Cin),\n"
code += " wire final_S, final_Cout;\n"
code += " assign out_S = final_S & valid;\n"
code += " assign out_Cout = final_Cout & valid;\n"
code += f"""\
WallaceTree{num_addends}Input u_WallaceTree{num_addends}Input(
.addends(addends),
{cin_cout_assign_code}
.final_S(final_S),
.final_Cout(final_Cout)
);
"""
reset_code = ""
reg_assign_code = ""
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1:
reset_code += (
f" L{i}_Cout_L{i+1}_Cin_reg <= {full_adder_count}'b0;\n"
)
reg_assign_code += f" L{i}_Cout_L{i+1}_Cin_reg <= L{i}_Cout&{{{full_adder_count}{{valid}}}};\n"
code += f"""\
always @ (posedge clk or negedge rstn) begin
if (!rstn) begin
{reset_code}
end
else begin
{reg_assign_code}
end
end
"""
code += "endmodule\n\n"
return code
def run(name: str, config: CFG):
# Setup argument parser
# parser = argparse.ArgumentParser(
# description="Generate Verilog code for Wallace Tree configurations."
# )
# parser.add_argument(
# "num_addends", type=int, help="Number of addends for the Wallace Tree."
# )
for i in tqdm(range(1, 4000)):
# Generate the configuration for the Wallace Tree
full_adder_list, remainder_list, total_input_list = gen_wallace_tree_config(i)
# Generate the FullAdder module
full_adder_code = gen_fulladder()
# Generate the basic Wallace Tree
wallace_tree_code = gen_wallacetree(
i, full_adder_list, remainder_list, total_input_list
)
serial_wallace_tree_code = gen_serialwallacetree(i, full_adder_list)
# Prepare the output code
if i == 1:
output_code = full_adder_code + wallace_tree_code + serial_wallace_tree_code
else:
output_code = wallace_tree_code + serial_wallace_tree_code
# Output the code to a file
os.makedirs(os.path.join(config.output_dir, name), exist_ok=True)
output_filename = os.path.join(
config.output_dir, name, f"SerialWallaceTree{i}Input.v"
)
with open(output_filename, "w") as file:
file.write(output_code)
file_dir = os.path.join(config.output_dir, name)
print("Files generated in", file_dir)
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.utils import calculate_WW
# %%
def generate_module(
module_name,
H=16,
L=5,
VN=512,
value_range=[-1, 1],
weights_file_name=None,
config: CFG = None,
):
with ModuleBlock(module_name) as module:
GN = config.group_number
GP = int(VN / GN)
# 参数
add_parameter("H", H)
add_parameter("L", L)
add_parameter("VN", VN)
# 输入输出
add_input("clk")
add_input("tree_rstn")
add_input("valid")
add_input("CST_LOW")
add_input("LM_sel", "L")
add_input("SW_in", "H")
for i in range(len(value_range)):
add_output(name=f"WT_{i}_out_S", height="VN")
add_output(name=f"WT_{i}_out_C", height="VN")
# 内部连线
for i in range(GP):
sw_ports = {
"clk": "clk",
"tree_rstn": "tree_rstn",
"valid": "valid",
"CST_LOW": "CST_LOW",
"SW_in": "SW_in",
"LM_sel": "LM_sel",
}
for j in range(len(value_range)):
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i*GN+GN-1}:{i*GN}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i*GN+GN-1}:{i*GN}]"
add_instance(
f"Mid_wrapper_tp_{weights_file_name}_gp_{i}",
f"Mid_wrapper_{i}",
None,
sw_ports,
)
return module
# %%
def process_task(i, name, weights_file_name, H, L, VN, config: CFG):
try:
file_dir = os.path.join(config.output_dir, name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}.sv")
module_name = f"{name}_tp_{weights_file_name}"
with open(file_name, "w") as f:
f.write(
generate_module(
module_name,
H=H,
L=L,
VN=VN,
value_range=config.value_range,
weights_file_name=weights_file_name,
config=config,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task,
i,
name,
weights_file_name,
H,
L,
VN,
config,
)
for i in range(1)
]
for future in tqdm(as_completed(futures), total=1):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
GenerateBlock,
ForBlock,
add_parameter,
add_input,
add_output,
add_genvar,
add_assign,
add_wire,
add_body,
add_instance,
add_newline,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from tqdm import tqdm
from hllm.config import CFG
import pickle
from hllm.utils import calculate_WW, find_index
def generate_module(
matrix,
module_name="",
H=16,
L=5,
value_range=[-1, 1],
WW=[8, 8],
):
with ModuleBlock(f"{module_name}") as module:
# 参数)
for i in range(len(value_range)):
add_parameter(f"WW_{i}", WW[i])
# 输入输出
add_input("clk")
add_input("tree_rstn")
add_input("valid")
for i in range(len(value_range)):
add_input(f"WT_{i}_in", f"WW_{i}")
add_output(
f"WT_{i}_out_S",
)
add_output(
f"WT_{i}_out_C",
)
# 内部连线华莱士树
for i in range(len(value_range)):
wallace_name = f"SerialWallaceTree{WW[i]}Input"
wallace_port = {
"clk": "clk",
"rstn": "tree_rstn",
"valid": "valid",
"addends": f"WT_{i}_in",
"out_S": f"WT_{i}_out_S",
"out_Cout": f"WT_{i}_out_C",
}
add_instance(wallace_name, f"serial_wallace_tree_{i}", {}, wallace_port)
return module
def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
try:
WW = calculate_WW(matrix, config.value_range)
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.sv")
with open(file_name, "w") as f:
f.write(
generate_module(
matrix,
module_name=f"{name}_tp_{weights_file_name}_vc_{i}",
H=H,
L=L,
value_range=config.value_range,
WW=WW,
).generate()
)
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
import sys
import numpy as np
from pyrilog import (
VerilogGenerator,
ModuleBlock,
add_parameter,
add_input,
add_output,
add_assign,
add_wire,
add_instance,
)
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
from hllm.config import CFG
from tqdm import tqdm
import pickle
from hllm.log import TCL
from hllm.utils import calculate_WW
def process_task(i, name, weights_file_name, matrix, H, L, config: CFG = None):
try:
WW = calculate_WW(matrix, config.value_range)
file_dir = os.path.join(config.output_dir, name, weights_file_name)
os.makedirs(file_dir, exist_ok=True)
file_name = os.path.join(file_dir, f"{name}_tp_{weights_file_name}_vc_{i}.txt")
with open(file_name, "w") as f:
for ww in WW:
f.write(str(ww) + "\n")
return i # 返回任务ID以显示进度
except Exception as e:
print(f"Generating {i} failed with an error: {e}")
return None
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
with ProcessPoolExecutor(max_workers=config.num_workers) as executor:
futures = [
executor.submit(
process_task, i, name, weights_file_name, matrixs[i], H, L, config
)
for i in range(VN)
]
for future in tqdm(as_completed(futures), total=VN):
try:
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
import numpy as np
import os
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from hllm.config import CFG
def calculate_WW(matrix: np.array, value_range):
"""计算每个value_range值在矩阵中每行出现的最大次数"""
WW = [0] * len(value_range)
for i in range(len(value_range)):
WW[i] = max(
[len([x for x in row if abs(x - value_range[i]) <= 0.01]) for row in matrix]
)
return WW
def find_index(arr, target, epsilon=1e-3):
"""在数组中查找最接近目标值的索引"""
arr = np.array(arr) # 转换为numpy数组
diff = np.abs(arr - target) # 计算差值数组
min_diff = np.min(diff) # 找到最小的差值
if min_diff < epsilon: # 如果最小差值在允许的误差范围内
return np.where(diff == min_diff)[0][0] # 返回第一个匹配的索引
raise ValueError("No match found") # 如果没有找到匹配项,则引发异常
import sys
import numpy as np
import pickle
import os
from tqdm import tqdm
from hllm.config import CFG
def to_8bit_binary(val):
if val < 0:
return f"{(1 << 8) + val:08b}"
else:
return f"{val:08b}"
def run(config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, weights_file_name)
filename_pkl = os.path.join(file_dir, "activation.pkl")
filename_txt = os.path.join(file_dir, "activation.txt")
filename_bin_txt = os.path.join(file_dir, "activation-bin.txt")
if os.path.exists(filename_pkl) and not config.verify_generate_activation_on_exist:
print(f"Activation file {filename_pkl} already exists")
return
print(f"Generating activation for {weights_file_name} in {file_dir}")
os.makedirs(file_dir, exist_ok=True)
with open(weights_file, mode="rb") as f:
weights = pickle.load(f)
shape = weights.shape
length = shape[2]
activation = np.random.randint(-128, 128, (1, length))
with open(filename_pkl, "wb") as f:
pickle.dump(activation, f)
with open(filename_txt, "w") as f:
for val in activation[0]:
f.write(f"{val}\n")
f.write("\n")
with open(filename_bin_txt, "w") as f:
for val in activation[0]:
f.write(f"{to_8bit_binary(val)}\n")
f.write("\n")
import sys
import numpy as np
import pickle
import os
from tqdm import tqdm
from hllm.config import CFG
def to_8bit_binary(val):
if val < 0:
return f"{(1 << 8) + val:08b}"
else:
return f"{val:08b}"
def get_bit(num, i):
if i < 0:
return 0
return (num >> i) & 1
def run(config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, weights_file_name)
print("开始进行激活测试")
print(f"读取权重文件{weights_file_name}")
with open(weights_file, "rb") as f:
weights = pickle.load(f)
activation_file = os.path.join(file_dir, "activation.pkl")
with open(activation_file, "rb") as f:
activation = pickle.load(f)
results_txt = os.path.join(file_dir, "result.txt")
with open(results_txt, "w") as f:
for layer in weights:
for i in range(8):
activation_bit = get_bit(activation, i)
tem = np.matmul(activation_bit, layer.T)
for val in tem[0]:
f.write(f"{val} ")
f.write("\n")
# with open(result_manual_txt, "w") as f:
# for layer in matrixs:
# tem = np.zeros((1, layer.shape[1]))
# for i, row in enumerate(layer):
# for j, val in enumerate(row):
# tem[0][j] += activation[0][i] * val
# f.write(f"{tem}\n")
print(f"结果写入{results_txt}")
from hllm.config import CFG
def run_origin(config: CFG):
import hllm.origin.generate_layer_mux as generate_layer_mux
import hllm.origin.generate_hn as generate_hn
import hllm.origin.generate_mid_wrapper as generate_mid_wrapper
import hllm.origin.generate_fsm as generate_fsm
import hllm.origin.generate_sub_wrapper as generate_sub_wrapper
import hllm.origin.generate_wallace as generate_wallace
import hllm.origin.generate_wrappers as generate_wrappers
import hllm.origin.generate_wt_group as generate_wt_group
import hllm.origin.generate_ww as generate_ww
config.output_dir = "outputs-qwen/origin"
generate_ww.run(name="WW", config=config)
generate_layer_mux.run(name="Layer_mux", config=config)
generate_hn.run(name="HN", config=config)
generate_mid_wrapper.run(name="Mid_wrapper", config=config)
generate_fsm.run(name="FSM", config=config)
generate_sub_wrapper.run(name="Sub_wrapper", config=config)
generate_wallace.run(name="SerialWallaceTree", config=config)
generate_wrappers.run(name="Wrappers", config=config)
generate_wt_group.run(name="WT_group", config=config)
def run_optimized(config: CFG):
import hllm.origin.generate_wallace as generate_wallace
import hllm.optimized.generate_info as generate_info
import hllm.optimized.generate_mux_wrapper as generate_mux_wrapper
import hllm.optimized.generate_mux as generate_mux
import hllm.optimized.generate_sub_wrapper as generate_sub_wrapper
import hllm.optimized.generate_wt_group as generate_wt_group
import hllm.optimized.generate_mid_wrapper as generate_mid_wrapper
import hllm.optimized.generate_wrappers as generate_wrappers
import hllm.optimized.generate_layer_mux as generate_layer_mux
import hllm.optimized.generate_fsm as generate_fsm
config.output_dir = "outputs-qwen/optimized"
generate_info.run(name="info", config=config)
generate_mux_wrapper.run(name="Mux_wrapper", config=config)
generate_mux.run(name="Mux", config=config)
generate_sub_wrapper.run(name="Sub_wrapper", config=config)
generate_wt_group.run(name="WT_group", config=config)
generate_mid_wrapper.run(name="Mid_wrapper", config=config)
generate_wallace.run(name="SerialWallaceTree", config=config)
generate_wrappers.run(name="Wrappers", config=config)
generate_layer_mux.run(name="Layer_mux", config=config)
generate_fsm.run(name="FSM", config=config)
def run_weights_preprocess(config: CFG):
import hllm.eda.generate_quant_weights as generate_quant_weights
import hllm.eda.mapping_weights as generate_mapping_weights
generate_quant_weights.run(config=config)
generate_mapping_weights.run(config=config)
def run_verify(config: CFG):
import hllm.verify.generate_activation as generate_activation
import hllm.verify.verify_activation as verify_activation
config.output_dir = "outputs-qwen/verify"
generate_activation.run(config=config)
verify_activation.run(config=config)
def batch_run(config: CFG):
for weights in config.run_weights_batch:
config.run_weights = weights
run_origin(config)
run_optimized(config)
run_verify(config)
if __name__ == "__main__":
config = CFG()
# run_weights_preprocess(config)
# run_origin(config)
# run_optimized(config)
# run_verify()
batch_run(config)
\ No newline at end of file
from setuptools import setup, find_packages, Extension
import pybind11
ext_modules = [
Extension(
"hllm.optimized.turbo_optimize_hn", # 注意这里的模块路径要匹配包结构
["hllm/optimized/turbo_optimize_hn.cpp"],
include_dirs=[pybind11.get_include()],
language="c++",
extra_compile_args=["-std=c++11", "-fPIC", "-O3"],
extra_link_args=["-static-libstdc++"],
),
]
setup(
name="hllm",
version="0.1.0",
packages=find_packages(),
install_requires=[
"pybind11>=2.6.0",
],
ext_modules=ext_modules,
)
// 文件及模块命名FSM_tp{weight_type}_gp{i}, 在该文件中寄存器宽度等变量尽可能保持一致以简化代码(开销很小,例如结果的位宽就按最宽的WT来)
//(* keep_hierarchy = "yes" *)
module FSM #(
parameter H = 1536, // hidden layer dim, 1536
parameter L = 52, // layer num, 52
parameter VN = 512, // vector num, 1536 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 17, // DG, related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // DG, log2(SCW)
parameter TTW = 26 // DG, MAC output total width, SCW + WP + 4
) (
input clk,
input valid,
input fsm_rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
output reg [TTW - 1 : 0] WT_result_acc[VN - 1 : 0],
output reg result_valid
);
reg [AP - 1 : 0] TM_sel;
wire [ H - 1 : 0] TM_out;
Top_mux #(
.H (H),
.AP(AP)
) top_mux (
.TM_sel(TM_sel),
.TM_in (Top_in),
.TM_out(TM_out)
);
wire WT_v0_out_S[VN - 1 : 0];
wire WT_v0_out_C[VN - 1 : 0];
wire WT_v1_out_S[VN - 1 : 0];
wire WT_v1_out_C[VN - 1 : 0];
wire WT_v2_out_S[VN - 1 : 0];
wire WT_v2_out_C[VN - 1 : 0];
wire WT_v3_out_S[VN - 1 : 0];
wire WT_v3_out_C[VN - 1 : 0];
wire WT_v4_out_S[VN - 1 : 0];
wire WT_v4_out_C[VN - 1 : 0];
wire WT_v5_out_S[VN - 1 : 0];
wire WT_v5_out_C[VN - 1 : 0];
wire WT_v6_out_S[VN - 1 : 0];
wire WT_v6_out_C[VN - 1 : 0];
wire WT_v7_out_S[VN - 1 : 0];
wire WT_v7_out_C[VN - 1 : 0];
wire WT_v8_out_S[VN - 1 : 0];
wire WT_v8_out_C[VN - 1 : 0];
wire WT_v9_out_S[VN - 1 : 0];
wire WT_v9_out_C[VN - 1 : 0];
wire WT_v10_out_S[VN - 1 : 0];
wire WT_v10_out_C[VN - 1 : 0];
wire WT_v11_out_S[VN - 1 : 0];
wire WT_v11_out_C[VN - 1 : 0];
wire WT_v12_out_S[VN - 1 : 0];
wire WT_v12_out_C[VN - 1 : 0];
wire WT_v13_out_S[VN - 1 : 0];
wire WT_v13_out_C[VN - 1 : 0];
reg tree_rstn;
reg [SCW - 1 : 0] final_S_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v13[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v13[VN - 1 : 0];
wire [TTW - 1 : 0] MAC_out[VN - 1 : 0];
// 不用加位宽,直接左移,高位舍弃
reg [SCW - 1 : 0] WT_result_v0[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v1[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v2[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v3[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v4[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v5[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v6[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v7[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v8[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v9[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v10[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v11[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v12[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v13[VN - 1 : 0];
reg [SCWB : 0] idx;
reg [2 : 0] state;
reg CST_LOW;
genvar j;
integer i;
Mid_wrapper_tp_k_gp_0 #(
.H (H),
.L (L),
.VN(VN)
) mid_wrappers (
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(TM_out),
.WT_0_out_S(WT_v0_out_S),
.WT_0_out_C(WT_v0_out_C),
.WT_1_out_S(WT_v1_out_S),
.WT_1_out_C(WT_v1_out_C),
.WT_2_out_S(WT_v2_out_S),
.WT_2_out_C(WT_v2_out_C),
.WT_3_out_S(WT_v3_out_S),
.WT_3_out_C(WT_v3_out_C),
.WT_4_out_S(WT_v4_out_S),
.WT_4_out_C(WT_v4_out_C),
.WT_5_out_S(WT_v5_out_S),
.WT_5_out_C(WT_v5_out_C),
.WT_6_out_S(WT_v6_out_S),
.WT_6_out_C(WT_v6_out_C),
.WT_7_out_S(WT_v7_out_S),
.WT_7_out_C(WT_v7_out_C),
.WT_8_out_S(WT_v8_out_S),
.WT_8_out_C(WT_v8_out_C),
.WT_9_out_S(WT_v9_out_S),
.WT_9_out_C(WT_v9_out_C),
.WT_10_out_S(WT_v10_out_S),
.WT_10_out_C(WT_v10_out_C),
.WT_11_out_S(WT_v11_out_S),
.WT_11_out_C(WT_v11_out_C),
.WT_12_out_S(WT_v12_out_S),
.WT_12_out_C(WT_v12_out_C),
.WT_13_out_S(WT_v13_out_S),
.WT_13_out_C(WT_v13_out_C)
);
generate
for (j = 0; j < VN; j = j + 1) begin : inst_SW_loop
MAC #(
.W_1(SCW), // input 1 width
.W_2(WP), // input 2 width
.W_O(TTW), // output width
.NUM(16) // parallel width
) mac (
.clk(clk),
.rstn(fsm_rstn),
.MAC_in_1({
{SCW{1'b0}},
{SCW{1'b0}},
WT_result_v13[j],
WT_result_v12[j],
WT_result_v11[j],
WT_result_v10[j],
WT_result_v9[j],
WT_result_v8[j],
WT_result_v7[j],
WT_result_v6[j],
WT_result_v5[j],
WT_result_v4[j],
WT_result_v3[j],
WT_result_v2[j],
WT_result_v1[j],
WT_result_v0[j]
}),
//.MAC_in_2({weight_0, weight_1, -5'd6, -5'd4, -5'd3, -5'd2, -5'd1, 5'd0, 5'd1, 5'd2, 5'd3, 5'd4, 5'd6, 5'd8, 5'd12, 5'd0}),
.MAC_out(MAC_out[j])
);
end
endgenerate
always @(posedge clk or negedge fsm_rstn) begin
if (!fsm_rstn) begin
state <= 0;
idx <= 0;
tree_rstn <= 0;
result_valid <= 0;
CST_LOW <= 0;
TM_sel <= 8'b00000000;
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= 0;
end
end else begin
if (state == 0) begin
idx <= 0;
tree_rstn <= 0;
result_valid <= 0;
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
if (valid == 1) begin
state <= 1;
end
end
if (state == 1) begin
tree_rstn <= 1;
if (idx == 0) begin //此时tree_rstn应该还是0
TM_sel <= 8'b00000001;
end else begin
if (idx == 1) begin
TM_sel <= 8'b00000010;
end
if (idx == 2) begin
TM_sel <= 8'b00000100;
end
if (idx == 3) begin
TM_sel <= 8'b00001000;
end
if (idx == 4) begin
TM_sel <= 8'b00010000;
end
if (idx == 5) begin
TM_sel <= 8'b00100000;
end
if (idx == 6) begin
TM_sel <= 8'b01000000;
end
if (idx == 7) begin
TM_sel <= 8'b10000000;
end
if (idx > 7) begin
TM_sel <= 8'b10000000;
end // signed extension
if (idx == SCW) begin
state <= 2;
end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i][idx-1] <= WT_v0_out_S[i];
final_C_v0[i][idx-1] <= WT_v0_out_C[i];
final_S_v1[i][idx-1] <= WT_v1_out_S[i];
final_C_v1[i][idx-1] <= WT_v1_out_C[i];
final_S_v2[i][idx-1] <= WT_v2_out_S[i];
final_C_v2[i][idx-1] <= WT_v2_out_C[i];
final_S_v3[i][idx-1] <= WT_v3_out_S[i];
final_C_v3[i][idx-1] <= WT_v3_out_C[i];
final_S_v4[i][idx-1] <= WT_v4_out_S[i];
final_C_v4[i][idx-1] <= WT_v4_out_C[i];
final_S_v5[i][idx-1] <= WT_v5_out_S[i];
final_C_v5[i][idx-1] <= WT_v5_out_C[i];
final_S_v6[i][idx-1] <= WT_v6_out_S[i];
final_C_v6[i][idx-1] <= WT_v6_out_C[i];
final_S_v7[i][idx-1] <= WT_v7_out_S[i];
final_C_v7[i][idx-1] <= WT_v7_out_C[i];
final_S_v8[i][idx-1] <= WT_v8_out_S[i];
final_C_v8[i][idx-1] <= WT_v8_out_C[i];
final_S_v9[i][idx-1] <= WT_v9_out_S[i];
final_C_v9[i][idx-1] <= WT_v9_out_C[i];
final_S_v10[i][idx-1] <= WT_v10_out_S[i];
final_C_v10[i][idx-1] <= WT_v10_out_C[i];
final_S_v11[i][idx-1] <= WT_v11_out_S[i];
final_C_v11[i][idx-1] <= WT_v11_out_C[i];
final_S_v12[i][idx-1] <= WT_v12_out_S[i];
final_C_v12[i][idx-1] <= WT_v12_out_C[i];
final_S_v13[i][idx-1] <= WT_v13_out_S[i];
final_C_v13[i][idx-1] <= WT_v13_out_C[i];
end
end
idx <= idx + 1;
end
if (state == 2) begin
for (
i = 0; i < VN; i = i + 1
) begin // 有符号数加法,补的0直接被截断了,没关系
WT_result_v0[i] <= {1'b0, final_S_v0[i]} + {final_C_v0[i], 1'b0};
WT_result_v1[i] <= {1'b0, final_S_v1[i]} + {final_C_v1[i], 1'b0};
WT_result_v2[i] <= {1'b0, final_S_v2[i]} + {final_C_v2[i], 1'b0};
WT_result_v3[i] <= {1'b0, final_S_v3[i]} + {final_C_v3[i], 1'b0};
WT_result_v4[i] <= {1'b0, final_S_v4[i]} + {final_C_v4[i], 1'b0};
WT_result_v5[i] <= {1'b0, final_S_v5[i]} + {final_C_v5[i], 1'b0};
WT_result_v6[i] <= {1'b0, final_S_v6[i]} + {final_C_v6[i], 1'b0};
WT_result_v7[i] <= {1'b0, final_S_v7[i]} + {final_C_v7[i], 1'b0};
WT_result_v8[i] <= {1'b0, final_S_v8[i]} + {final_C_v8[i], 1'b0};
WT_result_v9[i] <= {1'b0, final_S_v9[i]} + {final_C_v9[i], 1'b0};
WT_result_v10[i] <= {1'b0, final_S_v10[i]} + {final_C_v10[i], 1'b0};
WT_result_v11[i] <= {1'b0, final_S_v11[i]} + {final_C_v11[i], 1'b0};
WT_result_v12[i] <= {1'b0, final_S_v12[i]} + {final_C_v12[i], 1'b0};
WT_result_v13[i] <= {1'b0, final_S_v13[i]} + {final_C_v13[i], 1'b0};
end
idx <= 0;
state <= 3;
tree_rstn <= 0;
end
if (state == 3) begin
// MAC adder
state <= 4;
end
if (state == 4) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= MAC_out[i];
end
state <= 5;
result_valid <= 1;
end
if (state == 5) begin
// 输出及其他握手信号,待用
idx <= 0;
state <= 0;
tree_rstn <= 0;
end
end
end
endmodule
`define WW_0 8 // wallance tree width
`define WW_1 8 // wallance tree width
`define SCW 9 // S and Cout width, can be cauculated by max WW
`define SCWB 4 //log2(SCW), related to fsm state and idx count
`define H 16 // hidden layer dim, 1536
`define L 5 // layer num, 52
`define VN 5 // vector num, 1536 or 3840(in FFN)
`define WP 4 // weight precision
`define WS 16 // weight state ( = 2 ^ WP )
`define AP 8 // activation precision 因为涉及到状态数量啥的,FSM里面分开写了,如果要变动的话,除了改这里还得进去改
`define TTW 16 // total width, related to SCW and WS, and the behavior of accumulation(int or unsigned int)
//TODO现有代码所有的vector共用一个HN,需编写脚本生成可以具有不同HN的sub_wrapper的Verilog代码。
module HN #(
parameter WW_0 = 8,
parameter WW_1 = 8,
parameter H = 16,
parameter L = 5
) (
input [H - 1 : 0] HN_in,
output [WW_0 - 1 : 0] HN_out_0[0 : L - 1],
output [WW_1 - 1 : 0] HN_out_1[0 : L - 1]
);
// TODO ,目前随便写了一个,逻辑是:第0层:HN_out_0是偶数项,HN_out_1是奇数项;第1层相反。
// assign HN_out_0[0][0] = HN_in[0];
// assign HN_out_0[0][1] = HN_in[2];
// assign HN_out_0[0][2] = HN_in[4];
// assign HN_out_0[0][3] = HN_in[6];
// assign HN_out_0[0][4] = HN_in[8];
// assign HN_out_0[0][5] = HN_in[10];
// assign HN_out_0[0][6] = HN_in[12];
// assign HN_out_0[0][7] = HN_in[14];
// assign HN_out_1[0][0] = HN_in[1];
// assign HN_out_1[0][1] = HN_in[3];
// assign HN_out_1[0][2] = HN_in[5];
// assign HN_out_1[0][3] = HN_in[7];
// assign HN_out_1[0][4] = HN_in[9];
// assign HN_out_1[0][5] = HN_in[11];
// assign HN_out_1[0][6] = HN_in[13];
// assign HN_out_1[0][7] = HN_in[15];
// assign HN_out_0[1][0] = HN_in[1];
// assign HN_out_0[1][1] = HN_in[3];
// assign HN_out_0[1][2] = HN_in[5];
// assign HN_out_0[1][3] = HN_in[7];
// assign HN_out_0[1][4] = HN_in[9];
// assign HN_out_0[1][5] = HN_in[11];
// assign HN_out_0[1][6] = HN_in[13];
// assign HN_out_0[1][7] = HN_in[15];
// assign HN_out_1[1][0] = HN_in[0];
// assign HN_out_1[1][1] = HN_in[2];
// assign HN_out_1[1][2] = HN_in[4];
// assign HN_out_1[1][3] = HN_in[6];
// assign HN_out_1[1][4] = HN_in[8];
// assign HN_out_1[1][5] = HN_in[10];
// assign HN_out_1[1][6] = HN_in[12];
// assign HN_out_1[1][7] = HN_in[14];
endmodule
module Top_mux // abbr. TM
#(
parameter H = 16,
parameter AP = 8
) (
input [AP - 1 : 0] TM_sel,
input [AP - 1 : 0] TM_in [H - 1 : 0],
output [ H - 1 : 0] TM_out
);
wire [H - 1 : 0] TM_in_T[AP - 1 : 0];
wire [H - 1 : 0] TM_masked[AP - 1 : 0];
wire [AP - 1 : 0] TM_masked_T[H - 1 : 0];
genvar j;
genvar k;
generate
for (k = 0; k < H; k = k + 1) begin : TM_transpose_loop_0_out
for (j = 0; j < AP; j = j + 1) begin : TM_transpose_loop_0_in
assign TM_in_T[j][k] = TM_in[k][j];
end
end
endgenerate
genvar i;
generate
for (i = 0; i < AP; i = i + 1) begin : TM_select_loop
assign TM_masked[i] = TM_in_T[i] & {H{TM_sel[i]}};
end
endgenerate
genvar m;
genvar n;
generate
for (m = 0; m < H; m = m + 1) begin : TM_transpose_loop_1_out
for (n = 0; n < AP; n = n + 1) begin : TM_transpose_loop_1_in
assign TM_masked_T[m][n] = TM_masked[n][m];
end
end
endgenerate
genvar p;
generate
for (p = 0; p < H; p = p + 1) begin : TM_reduce_or_loop
assign TM_out[p] = |(TM_masked_T[p]);
end
endgenerate
endmodule
module Layer_mux // abbr. LM
#(
parameter L = 4,
parameter WW_0 = 8,
parameter WW_1 = 8
) (
input [L - 1 : 0] LM_sel,
input [WW_0 - 1 : 0] LM_in_0[L - 1 : 0],
input [WW_1 - 1 : 0] LM_in_1[L - 1 : 0],
output [WW_0 - 1 : 0] LM_out_0,
output [WW_1 - 1 : 0] LM_out_1
);
// function: if LM_sel[i] == 1, LM_out_0 = LM_in_0[i]
wire [WW_0 - 1 : 0] LM_in_0_masked[L - 1 : 0];
wire [WW_1 - 1 : 0] LM_in_1_masked[L - 1 : 0];
wire [L - 1 : 0] LM_in_0_masked_T[WW_0 - 1 : 0];
wire [L - 1 : 0] LM_in_1_masked_T[WW_1 - 1 : 0];
genvar i;
generate
for (i = 0; i < L; i = i + 1) begin : LM_select_loop
assign LM_in_0_masked[i] = LM_in_0[i] & {WW_0{LM_sel[i]}};
assign LM_in_1_masked[i] = LM_in_1[i] & {WW_1{LM_sel[i]}};
end
endgenerate
genvar j;
genvar k;
generate
for (k = 0; k < L; k = k + 1) begin : LM_transpose_loop_out
for (j = 0; j < WW_0; j = j + 1) begin : LM_transpose_loop_in_0
assign LM_in_0_masked_T[j][k] = LM_in_0_masked[k][j];
end
for (j = 0; j < WW_1; j = j + 1) begin : LM_transpose_loop_in_1
assign LM_in_1_masked_T[j][k] = LM_in_1_masked[k][j];
end
end
endgenerate
genvar m;
generate
for (m = 0; m < WW_0; m = m + 1) begin : LM_reduce_or_loop_0
assign LM_out_0[m] = |(LM_in_0_masked_T[m]);
end
for (m = 0; m < WW_0; m = m + 1) begin : LM_reduce_or_loop_1
assign LM_out_1[m] = |(LM_in_1_masked_T[m]);
end
endgenerate
endmodule
module WT_group #(
parameter WW_0 = 8,
parameter WW_1 = 8
) (
input clk,
input tree_rstn,
input valid,
input [WW_0 - 1 : 0] WT_0_in,
input [WW_1 - 1 : 0] WT_1_in,
output WT_0_out_S,
output WT_0_out_C,
output WT_1_out_S,
output WT_1_out_C
);
SerialWallaceTree8Input serial_wallace_tree_0 (
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_0_in),
.out_S(WT_0_out_S),
.out_Cout(WT_0_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_1 (
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_1_in),
.out_S(WT_1_out_S),
.out_Cout(WT_1_out_C)
);
endmodule
module Sub_wrapper // abbr. SW
#(
parameter WW_0 = 8,
parameter WW_1 = 8,
parameter H = 16,
parameter L = 5
) (
input clk,
input tree_rstn,
input valid,
input [L - 1 : 0] LM_sel,
input [H - 1 : 0] SW_in,
output WT_0_out_S,
output WT_0_out_C,
output WT_1_out_S,
output WT_1_out_C
);
wire [WW_0 - 1 : 0] HN_out_0[0 : L - 1];
wire [WW_1 - 1 : 0] HN_out_1[0 : L - 1];
HN #(
.WW_0(WW_0),
.WW_1(WW_1),
.H(H),
.L(L)
) hn (
.HN_in(SW_in),
.HN_out_0(HN_out_0),
.HN_out_1(HN_out_1)
);
wire [WW_0 - 1 : 0] LM_out_0;
wire [WW_1 - 1 : 0] LM_out_1;
Layer_mux #(
.L(L),
.WW_0(WW_0),
.WW_1(WW_1)
) layer_mux (
.LM_sel (LM_sel),
.LM_in_0 (HN_out_0),
.LM_in_1 (HN_out_1),
.LM_out_0(LM_out_0),
.LM_out_1(LM_out_1)
);
WT_group #(
.WW_0(WW_0),
.WW_1(WW_1)
) wt_group (
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.WT_0_in(LM_out_0),
.WT_1_in(LM_out_1),
.WT_0_out_S(WT_0_out_S),
.WT_0_out_C(WT_0_out_C),
.WT_1_out_S(WT_1_out_S),
.WT_1_out_C(WT_1_out_C)
);
endmodule
module MAC #(
parameter W_1 = 8, // input 1 width
parameter W_2 = 4, // input 2 width
parameter W_O = 16 // output width
) (
input clk,
input rstn,
input signed [W_1 - 1 : 0] MAC_in_1,
input signed [W_2 - 1 : 0] MAC_in_2,
output reg signed [W_O - 1 : 0] MAC_out
);
wire signed [W_1 + W_2 - 1 : 0] multi;
always @(negedge rstn or posedge clk) begin
if (!rstn) begin
MAC_out <= 0;
end else begin
MAC_out <= MAC_out + multi;
end
end
assign multi = MAC_in_1 * MAC_in_2;
endmodule
module top_FSM #(
parameter WW_0 = 8, // wallance tree width
parameter WW_1 = 8, // wallance tree width
parameter H = 16, // hidden layer dim, 1536
parameter L = 5, // layer num, 52
parameter VN = 5, // vector num, 1536 or 3840(in FFN)
parameter WP = 4, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 11,
parameter SCWB = 4,
parameter TTW = 16
) (
input clk,
input valid,
input fsm_rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
output reg [TTW : 0] WT_result_acc[VN - 1 : 0],
output reg result_valid
);
reg [AP - 1 : 0] TM_sel;
wire [ H - 1 : 0] TM_out;
Top_mux #(
.H (H),
.AP(AP)
) top_mux (
.TM_sel(TM_sel),
.TM_in (Top_in),
.TM_out(TM_out)
);
wire WT_0_out_S[VN - 1 : 0];
wire WT_0_out_C[VN - 1 : 0];
wire WT_1_out_S[VN - 1 : 0];
wire WT_1_out_C[VN - 1 : 0];
reg tree_rstn, mac_rstn;
reg [SCW - 1 : 0] final_S_0[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_0[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_1[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_1[VN - 1 : 0];
reg [SCW + 1 : 0] MAC_in_1[VN - 1 : 0];
reg [WP - 1 : 0] MAC_in_2[VN - 1 : 0];
wire [TTW : 0] MAC_out[VN - 1 : 0];
reg [SCW + 1 : 0] WT_result_0[VN - 1 : 0]; // increase 2 bit after shift and add
reg [SCW + 1 : 0] WT_result_1[VN - 1 : 0]; // increase 2 bit after shift and add
// reg [TTW : 0] WT_result_acc[VN - 1 : 0]; // accumulated WT_result
reg [SCWB : 0] idx;
reg [2 : 0] state;
genvar j;
integer i;
generate
for (j = 0; j < VN; j = j + 1) begin : inst_SW_loop
Sub_wrapper #(
.WW_0(WW_0),
.WW_1(WW_1),
.H(H),
.L(L)
) sub_wrapper (
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.LM_sel(LM_sel),
.SW_in(TM_out),
.WT_0_out_S(WT_0_out_S[j]),
.WT_0_out_C(WT_0_out_C[j]),
.WT_1_out_S(WT_1_out_S[j]),
.WT_1_out_C(WT_1_out_C[j])
);
MAC #(
.W_1(SCW), // input 1 width
.W_2(WP), // input 2 width
.W_O(TTW) // output width
) mac (
.clk(clk),
.rstn(mac_rstn),
.MAC_in_1(MAC_in_1[j]),
.MAC_in_2(MAC_in_2[j]),
.MAC_out(MAC_out[j])
);
end
endgenerate
always @(posedge clk or negedge fsm_rstn) begin
if (!fsm_rstn) begin
state <= 0;
idx <= 0;
tree_rstn <= 0;
mac_rstn <= 0;
result_valid <= 0;
TM_sel <= 8'b00000000;
for (i = 0; i < VN; i = i + 1) begin
MAC_in_1[i] <= 0;
MAC_in_2[i] <= 0;
end
end else begin
if (state == 0) begin
idx <= 0;
tree_rstn <= 0;
result_valid <= 0;
for (i = 0; i < VN; i = i + 1) begin
final_S_0[i] <= 0;
final_C_0[i] <= 0;
final_S_1[i] <= 0;
final_C_1[i] <= 0;
WT_result_0[i] <= 0;
WT_result_1[i] <= 0;
WT_result_acc[i] <= 0;
end
if (valid == 1) begin
state <= 1;
end
end
if (state == 1) begin
tree_rstn <= 1;
if (idx == 0) begin //此时tree_rstn应该还是0
TM_sel <= 8'b00000001;
end else begin
if (idx == 1) begin
TM_sel <= 8'b00000010;
end
if (idx == 2) begin
TM_sel <= 8'b00000100;
end
if (idx == 3) begin
TM_sel <= 8'b00001000;
end
if (idx == 4) begin
TM_sel <= 8'b00010000;
end
if (idx == 5) begin
TM_sel <= 8'b00100000;
end
if (idx == 6) begin
TM_sel <= 8'b01000000;
end
if (idx == 7) begin
TM_sel <= 8'b10000000;
end
if (idx > 7) begin
TM_sel <= 8'b00000000;
end
if (idx == SCW) begin
state <= 2;
end
for (i = 0; i < VN; i = i + 1) begin
final_S_0[i][idx-1] <= WT_0_out_S[i];
final_C_0[i][idx-1] <= WT_0_out_C[i];
final_S_1[i][idx-1] <= WT_1_out_S[i];
final_C_1[i][idx-1] <= WT_1_out_C[i];
end
end
idx <= idx + 1;
end
if (state == 2) begin
for (i = 0; i < VN; i = i + 1) begin : WT_result_loop
WT_result_0[i] <= final_S_0[i] + (final_C_0[i] << 1);
WT_result_1[i] <= final_S_1[i] + (final_C_1[i] << 1);
end
idx <= 0;
state <= 3;
tree_rstn <= 0;
mac_rstn <= 1;
end
if (state == 3) begin
if (idx == 0) begin
for (i = 0; i < VN; i = i + 1) begin : WT_acc_loop_0
MAC_in_1[i] <= WT_result_0[i];
MAC_in_2[i] <= 1; // 这个地方其实可以只用一个寄存器而不是VN个
end
idx <= idx + 1;
end
if (idx == 1) begin
for (i = 0; i < VN; i = i + 1) begin : WT_acc_loop_1
MAC_in_1[i] <= WT_result_1[i];
MAC_in_2[i] <= -1;
end
idx <= idx + 1;
end
if (idx == 2) begin
idx <= idx + 1;
// 需要空闲一个周期等待MAC_out更新
end
if (idx == 3) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= MAC_out[i];
end
idx <= idx + 1;
state <= 4;
result_valid <= 1;
// 放这里而不放后面,保证MAC_out不出现错误值
mac_rstn <= 0;
end
end
if (state == 4) begin
// 输出及其他握手信号,待用
idx <= 0;
state <= 0;
tree_rstn <= 0;
end
end
end
endmodule
module Mid_wrappers_tp_k #(
parameter H = 16, //这些数值无所谓
parameter L = 2,
parameter VN = 16
) (
input clk,
input tree_rstn,
input valid,
input CST_LOW,
input [L - 1 : 0] LM_sel,
input [H - 1 : 0] SW_in,
output WT_v0_out_S[VN - 1 : 0],
output WT_v0_out_C[VN - 1 : 0],
output WT_v1_out_S[VN - 1 : 0],
output WT_v1_out_C[VN - 1 : 0],
output WT_v2_out_S[VN - 1 : 0],
output WT_v2_out_C[VN - 1 : 0],
output WT_v3_out_S[VN - 1 : 0],
output WT_v3_out_C[VN - 1 : 0],
output WT_v4_out_S[VN - 1 : 0],
output WT_v4_out_C[VN - 1 : 0],
output WT_v5_out_S[VN - 1 : 0],
output WT_v5_out_C[VN - 1 : 0],
output WT_v6_out_S[VN - 1 : 0],
output WT_v6_out_C[VN - 1 : 0],
output WT_v7_out_S[VN - 1 : 0],
output WT_v7_out_C[VN - 1 : 0],
output WT_v8_out_S[VN - 1 : 0],
output WT_v8_out_C[VN - 1 : 0],
output WT_v9_out_S[VN - 1 : 0],
output WT_v9_out_C[VN - 1 : 0],
output WT_v10_out_S[VN - 1 : 0],
output WT_v10_out_C[VN - 1 : 0],
output WT_v11_out_S[VN - 1 : 0],
output WT_v11_out_C[VN - 1 : 0],
output WT_v12_out_S[VN - 1 : 0],
output WT_v12_out_C[VN - 1 : 0],
output WT_v13_out_S[VN - 1 : 0],
output WT_v13_out_C[VN - 1 : 0]
);
Sub_wrapper_tp_k_vc_0 sub_wrapper_0 //名称手动迭代下,还有下面的索引
(
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(SW_in),
.WT_0_out_S(WT_v0_out_S[0]),
.WT_0_out_C(WT_v0_out_C[0]),
.WT_1_out_S(WT_v1_out_S[0]),
.WT_1_out_C(WT_v1_out_C[0]),
.WT_2_out_S(WT_v2_out_S[0]),
.WT_2_out_C(WT_v2_out_C[0]),
.WT_3_out_S(WT_v3_out_S[0]),
.WT_3_out_C(WT_v3_out_C[0]),
.WT_4_out_S(WT_v4_out_S[0]),
.WT_4_out_C(WT_v4_out_C[0]),
.WT_5_out_S(WT_v5_out_S[0]),
.WT_5_out_C(WT_v5_out_C[0]),
.WT_6_out_S(WT_v6_out_S[0]),
.WT_6_out_C(WT_v6_out_C[0]),
.WT_7_out_S(WT_v7_out_S[0]),
.WT_7_out_C(WT_v7_out_C[0]),
.WT_8_out_S(WT_v8_out_S[0]),
.WT_8_out_C(WT_v8_out_C[0]),
.WT_9_out_S(WT_v9_out_S[0]),
.WT_9_out_C(WT_v9_out_C[0]),
.WT_10_out_S(WT_v10_out_S[0]),
.WT_10_out_C(WT_v10_out_C[0]),
.WT_11_out_S(WT_v11_out_S[0]),
.WT_11_out_C(WT_v11_out_C[0]),
.WT_12_out_S(WT_v12_out_S[0]),
.WT_12_out_C(WT_v12_out_C[0]),
.WT_13_out_S(WT_v13_out_S[0]),
.WT_13_out_C(WT_v13_out_C[0])
);
Sub_wrapper_tp_k_vc_1 sub_wrapper_1 (
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(SW_in),
.WT_0_out_S(WT_v0_out_S[1]),
.WT_0_out_C(WT_v0_out_C[1]),
.WT_1_out_S(WT_v1_out_S[1]),
.WT_1_out_C(WT_v1_out_C[1]),
.WT_2_out_S(WT_v2_out_S[1]),
.WT_2_out_C(WT_v2_out_C[1]),
.WT_3_out_S(WT_v3_out_S[1]),
.WT_3_out_C(WT_v3_out_C[1]),
.WT_4_out_S(WT_v4_out_S[1]),
.WT_4_out_C(WT_v4_out_C[1]),
.WT_5_out_S(WT_v5_out_S[1]),
.WT_5_out_C(WT_v5_out_C[1]),
.WT_6_out_S(WT_v6_out_S[1]),
.WT_6_out_C(WT_v6_out_C[1]),
.WT_7_out_S(WT_v7_out_S[1]),
.WT_7_out_C(WT_v7_out_C[1]),
.WT_8_out_S(WT_v8_out_S[1]),
.WT_8_out_C(WT_v8_out_C[1]),
.WT_9_out_S(WT_v9_out_S[1]),
.WT_9_out_C(WT_v9_out_C[1]),
.WT_10_out_S(WT_v10_out_S[1]),
.WT_10_out_C(WT_v10_out_C[1]),
.WT_11_out_S(WT_v11_out_S[1]),
.WT_11_out_C(WT_v11_out_C[1]),
.WT_12_out_S(WT_v12_out_S[1]),
.WT_12_out_C(WT_v12_out_C[1]),
.WT_13_out_S(WT_v13_out_S[1]),
.WT_13_out_C(WT_v13_out_C[1])
);
endmodule
module FSM_tp_k_gp_0_0127 #(
parameter H = 896, // hidden layer dim, 1536
parameter L = 24, // layer num, 52
parameter VN = 32, // vector num, 1536 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 18, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 32 // MAC output total width, SCW + WP + 4
) (
input clk,
input valid,
input fsm_rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
output reg [TTW - 1 : 0] WT_result_acc[VN - 1 : 0],
output reg result_valid
);
reg [AP - 1 : 0] TM_sel;
wire [ H - 1 : 0] TM_out;
Top_mux #(
.H (H),
.AP(AP)
) top_mux (
.TM_sel(TM_sel),
.TM_in (Top_in),
.TM_out(TM_out)
);
wire WT_v0_out_S[VN - 1 : 0];
wire WT_v0_out_C[VN - 1 : 0];
wire WT_v1_out_S[VN - 1 : 0];
wire WT_v1_out_C[VN - 1 : 0];
wire WT_v2_out_S[VN - 1 : 0];
wire WT_v2_out_C[VN - 1 : 0];
wire WT_v3_out_S[VN - 1 : 0];
wire WT_v3_out_C[VN - 1 : 0];
wire WT_v4_out_S[VN - 1 : 0];
wire WT_v4_out_C[VN - 1 : 0];
wire WT_v5_out_S[VN - 1 : 0];
wire WT_v5_out_C[VN - 1 : 0];
wire WT_v6_out_S[VN - 1 : 0];
wire WT_v6_out_C[VN - 1 : 0];
wire WT_v7_out_S[VN - 1 : 0];
wire WT_v7_out_C[VN - 1 : 0];
wire WT_v8_out_S[VN - 1 : 0];
wire WT_v8_out_C[VN - 1 : 0];
wire WT_v9_out_S[VN - 1 : 0];
wire WT_v9_out_C[VN - 1 : 0];
wire WT_v10_out_S[VN - 1 : 0];
wire WT_v10_out_C[VN - 1 : 0];
wire WT_v11_out_S[VN - 1 : 0];
wire WT_v11_out_C[VN - 1 : 0];
wire WT_v12_out_S[VN - 1 : 0];
wire WT_v12_out_C[VN - 1 : 0];
wire WT_v13_out_S[VN - 1 : 0];
wire WT_v13_out_C[VN - 1 : 0];
reg tree_valid; // carry reg clear if not valid
reg [SCW - 1 : 0] final_S_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v13[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v13[VN - 1 : 0];
wire [TTW - 1 : 0] MAC_out[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v0[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v1[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v2[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v3[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v4[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v5[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v6[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v7[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v8[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v9[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v10[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v11[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v12[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v13[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v0_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v1_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v2_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v3_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v4_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v5_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v6_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v7_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v8_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v9_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v10_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v11_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v12_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v13_wire[VN - 1 : 0];
reg [7 : 0] state_idx; // MSB 3: state; LSB 5: idx
reg [7 : 0] next_state_idx;
wire [2 : 0] state;
wire [4 : 0] idx;
assign state = state_idx[7 : 5];
assign idx = state_idx[4 : 0];
reg CST_LOW;
genvar j;
integer i;
Mid_wrapper_tp_k_gp_0 #(
.H (H),
.L (L),
.VN(VN)
) mid_wrappers (
.clk(clk),
.tree_rstn(fsm_rstn),
.valid(tree_valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(TM_out),
.WT_0_out_S(WT_v0_out_S),
.WT_0_out_C(WT_v0_out_C),
.WT_1_out_S(WT_v1_out_S),
.WT_1_out_C(WT_v1_out_C),
.WT_2_out_S(WT_v2_out_S),
.WT_2_out_C(WT_v2_out_C),
.WT_3_out_S(WT_v3_out_S),
.WT_3_out_C(WT_v3_out_C),
.WT_4_out_S(WT_v4_out_S),
.WT_4_out_C(WT_v4_out_C),
.WT_5_out_S(WT_v5_out_S),
.WT_5_out_C(WT_v5_out_C),
.WT_6_out_S(WT_v6_out_S),
.WT_6_out_C(WT_v6_out_C),
.WT_7_out_S(WT_v7_out_S),
.WT_7_out_C(WT_v7_out_C),
.WT_8_out_S(WT_v8_out_S),
.WT_8_out_C(WT_v8_out_C),
.WT_9_out_S(WT_v9_out_S),
.WT_9_out_C(WT_v9_out_C),
.WT_10_out_S(WT_v10_out_S),
.WT_10_out_C(WT_v10_out_C),
.WT_11_out_S(WT_v11_out_S),
.WT_11_out_C(WT_v11_out_C),
.WT_12_out_S(WT_v12_out_S),
.WT_12_out_C(WT_v12_out_C),
.WT_13_out_S(WT_v13_out_S),
.WT_13_out_C(WT_v13_out_C)
);
generate
for (j = 0; j < VN; j = j + 1) begin : inst_SW_loop
MAC #(
.W_1(SCW), // input 1 width
.W_2(WP), // input 2 width
.W_O(TTW), // output width
.NUM(16) // parallel width
) mac (
.clk(clk),
.rstn(fsm_rstn),
.MAC_in_1({
{SCW{1'b0}},
{SCW{1'b0}},
WT_result_v13[j],
WT_result_v12[j],
WT_result_v11[j],
WT_result_v10[j],
WT_result_v9[j],
WT_result_v8[j],
WT_result_v7[j],
WT_result_v6[j],
WT_result_v5[j],
WT_result_v4[j],
WT_result_v3[j],
WT_result_v2[j],
WT_result_v1[j],
WT_result_v0[j]
}),
//.MAC_in_2({weight_0, weight_1, -5'd6, -5'd4, -5'd3, -5'd2, -5'd1, 5'd0, 5'd1, 5'd2, 5'd3, 5'd4, 5'd6, 5'd8, 5'd12, 5'd0}),
.MAC_out(MAC_out[j])
);
end
endgenerate
genvar k;
generate
for (k = 0; k < VN; k = k + 1) begin
assign WT_result_v0_wire[k] = {1'b0, final_S_v0[k]} + {final_C_v0[k], 1'b0};
assign WT_result_v1_wire[k] = {1'b0, final_S_v1[k]} + {final_C_v1[k], 1'b0};
assign WT_result_v2_wire[k] = {1'b0, final_S_v2[k]} + {final_C_v2[k], 1'b0};
assign WT_result_v3_wire[k] = {1'b0, final_S_v3[k]} + {final_C_v3[k], 1'b0};
assign WT_result_v4_wire[k] = {1'b0, final_S_v4[k]} + {final_C_v4[k], 1'b0};
assign WT_result_v5_wire[k] = {1'b0, final_S_v5[k]} + {final_C_v5[k], 1'b0};
assign WT_result_v6_wire[k] = {1'b0, final_S_v6[k]} + {final_C_v6[k], 1'b0};
assign WT_result_v7_wire[k] = {1'b0, final_S_v7[k]} + {final_C_v7[k], 1'b0};
assign WT_result_v8_wire[k] = {1'b0, final_S_v8[k]} + {final_C_v8[k], 1'b0};
assign WT_result_v9_wire[k] = {1'b0, final_S_v9[k]} + {final_C_v9[k], 1'b0};
assign WT_result_v10_wire[k] = {1'b0, final_S_v10[k]} + {final_C_v10[k], 1'b0};
assign WT_result_v11_wire[k] = {1'b0, final_S_v11[k]} + {final_C_v11[k], 1'b0};
assign WT_result_v12_wire[k] = {1'b0, final_S_v12[k]} + {final_C_v12[k], 1'b0};
assign WT_result_v13_wire[k] = {1'b0, final_S_v13[k]} + {final_C_v13[k], 1'b0};
end
endgenerate
// fsm next state generation
always @(state_idx, valid) begin
case(state)
3'b000: begin
if (valid == 1) next_state_idx = 8'b00100000;
else next_state_idx = 0;
end
3'b001: begin
if (idx == SCW) next_state_idx = 8'b01000000;
else next_state_idx = state_idx + 1;
end
3'b010: next_state_idx = 8'b01100000;
3'b011: next_state_idx = 8'b10000000;
3'b100: next_state_idx = 0;
default: next_state_idx = 0;
endcase
end
// fsm state transfer
always @(posedge clk) begin
if(!fsm_rstn)
state_idx <= 0;
else
state_idx <= next_state_idx;
end
//fsm output
always @(posedge clk or negedge fsm_rstn) begin
if (!fsm_rstn) begin
//state <= 0;
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
CST_LOW <= 0;
TM_sel <= 8'b00000000;
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= 0;
end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
end else begin
if (state == 0) begin
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
/*
if (valid == 1) begin
state <= 1;
end
else begin
state <= 0;
end
*/
end
else if (state == 1) begin
tree_valid <= 1;
if (idx == 0) begin
TM_sel <= 8'b00000001;
end else begin
case (idx)
1: TM_sel <= 8'b00000010;
2: TM_sel <= 8'b00000100;
3: TM_sel <= 8'b00001000;
4: TM_sel <= 8'b00010000;
5: TM_sel <= 8'b00100000;
6: TM_sel <= 8'b01000000;
7: TM_sel <= 8'b10000000;
default: TM_sel <= 8'b10000000; // signed extension
endcase
//if (idx == SCW) begin
// state <= 2;
//end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i][idx-1] <= WT_v0_out_S[i];
final_C_v0[i][idx-1] <= WT_v0_out_C[i];
final_S_v1[i][idx-1] <= WT_v1_out_S[i];
final_C_v1[i][idx-1] <= WT_v1_out_C[i];
final_S_v2[i][idx-1] <= WT_v2_out_S[i];
final_C_v2[i][idx-1] <= WT_v2_out_C[i];
final_S_v3[i][idx-1] <= WT_v3_out_S[i];
final_C_v3[i][idx-1] <= WT_v3_out_C[i];
final_S_v4[i][idx-1] <= WT_v4_out_S[i];
final_C_v4[i][idx-1] <= WT_v4_out_C[i];
final_S_v5[i][idx-1] <= WT_v5_out_S[i];
final_C_v5[i][idx-1] <= WT_v5_out_C[i];
final_S_v6[i][idx-1] <= WT_v6_out_S[i];
final_C_v6[i][idx-1] <= WT_v6_out_C[i];
final_S_v7[i][idx-1] <= WT_v7_out_S[i];
final_C_v7[i][idx-1] <= WT_v7_out_C[i];
final_S_v8[i][idx-1] <= WT_v8_out_S[i];
final_C_v8[i][idx-1] <= WT_v8_out_C[i];
final_S_v9[i][idx-1] <= WT_v9_out_S[i];
final_C_v9[i][idx-1] <= WT_v9_out_C[i];
final_S_v10[i][idx-1] <= WT_v10_out_S[i];
final_C_v10[i][idx-1] <= WT_v10_out_C[i];
final_S_v11[i][idx-1] <= WT_v11_out_S[i];
final_C_v11[i][idx-1] <= WT_v11_out_C[i];
final_S_v12[i][idx-1] <= WT_v12_out_S[i];
final_C_v12[i][idx-1] <= WT_v12_out_C[i];
final_S_v13[i][idx-1] <= WT_v13_out_S[i];
final_C_v13[i][idx-1] <= WT_v13_out_C[i];
end
end
//idx <= idx + 1;
end
else if (state == 2) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_v0[i] <= WT_result_v0_wire[i];
WT_result_v1[i] <= WT_result_v1_wire[i];
WT_result_v2[i] <= WT_result_v2_wire[i];
WT_result_v3[i] <= WT_result_v3_wire[i];
WT_result_v4[i] <= WT_result_v4_wire[i];
WT_result_v5[i] <= WT_result_v5_wire[i];
WT_result_v6[i] <= WT_result_v6_wire[i];
WT_result_v7[i] <= WT_result_v7_wire[i];
WT_result_v8[i] <= WT_result_v8_wire[i];
WT_result_v9[i] <= WT_result_v9_wire[i];
WT_result_v10[i] <= WT_result_v10_wire[i];
WT_result_v11[i] <= WT_result_v11_wire[i];
WT_result_v12[i] <= WT_result_v12_wire[i];
WT_result_v13[i] <= WT_result_v13_wire[i];
end
//idx <= 0;
//state <= 3;
tree_valid <= 0;
end
else if (state == 3) begin
// MAC adder
//state <= 4;
end
else if (state == 4) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= MAC_out[i];
end
//state <= 5;
result_valid <= 1;
end
//if (state == 5) begin
// ouput, reserved cycle
//idx <= 0;
//state <= 0;
//tree_valid <= 0;
//end
end
end
endmodule
// WW_v0: WT width for value 0: -6 : 8
// WW_v1: WT width for value 1: -4 : 8
// WW_v2: WT width for value 2: -3 : 8
// WW_v3: WT width for value 3: -2 : 8
// WW_v4: WT width for value 4: -1.5 : 8
// WW_v5: WT width for value 5: -1 : 8
// WW_v6: WT width for value 6: -0.5 : 8
// WW_v7: WT width for value 7: 0.5 : 8
// WW_v8: WT width for value 8: 1 : 8
// WW_v9: WT width for value 9: 1.5 : 8
// WW_v10: WT width for value 10: 2 : 8
// WW_v11: WT width for value 11: 3 : 8
// WW_v12: WT width for value 12: 4 : 8
// WW_v13: WT width for value 13: 6 : 8
// 对于宽度WW参数传递的设计:以上注释为了可读性要在代码中保留,用实际值代替。模块内部参数传递用parameter,SW模块定
// 义时parameter的默认值要设成真实值,其他模块无所谓。顶层例化的时候不传递任何parameter(为了代码的复用性)
// 该文件名设置为SW,该文件包含的模块:HN,LM,WT_group,SW。也可以酌情把HN分出来。
// 模块及文件名命名规则范例:HN_tp{weight_type}_vc{id} tp:type, vc:vector, 加一个类似的前缀是为了检索或替换方便,类似于WW_v0中v表示value
// 生成时HN可以放最后面,太多了省的每次都得翻,其他代码量小的模块放前面
// to zhengzifu: DG : parameters to be dynamically generated
// 不用`define,防止多文件冲突,仅用parameter显式传参
module HN
#(
parameter WW_v0 = 8,
parameter WW_v1 = 8,
parameter WW_v2 = 8,
parameter WW_v3 = 8,
parameter WW_v4 = 8,
parameter WW_v5 = 8,
parameter WW_v6 = 8,
parameter WW_v7 = 8,
parameter WW_v8 = 8,
parameter WW_v9 = 8,
parameter WW_v10 = 8,
parameter WW_v11 = 8,
parameter WW_v12 = 8,
parameter WW_v13 = 8,
parameter H = 16,
parameter L = 5
)
(
input [H - 1 : 0] HN_in,
output [WW_v0 - 1 : 0] HN_out_v0[L - 1 : 0],
output [WW_v1 - 1 : 0] HN_out_v1[L - 1 : 0],
output [WW_v2 - 1 : 0] HN_out_v2[L - 1 : 0],
output [WW_v3 - 1 : 0] HN_out_v3[L - 1 : 0],
output [WW_v4 - 1 : 0] HN_out_v4[L - 1 : 0],
output [WW_v5 - 1 : 0] HN_out_v5[L - 1 : 0],
output [WW_v6 - 1 : 0] HN_out_v6[L - 1 : 0],
output [WW_v7 - 1 : 0] HN_out_v7[L - 1 : 0],
output [WW_v8 - 1 : 0] HN_out_v8[L - 1 : 0],
output [WW_v9 - 1 : 0] HN_out_v9[L - 1 : 0],
output [WW_v10 - 1 : 0] HN_out_v10[L - 1 : 0],
output [WW_v11 - 1 : 0] HN_out_v11[L - 1 : 0],
output [WW_v12 - 1 : 0] HN_out_v12[L - 1 : 0],
output [WW_v13 - 1 : 0] HN_out_v13[L - 1 : 0]
);
// TODO
assign HN_out_v0[0][0] = HN_in[1];
assign HN_out_v0[0][1] = HN_in[9];
endmodule
module Layer_mux // abbr. LM
#(
parameter L = 4,
parameter WW_v0 = 8,
parameter WW_v1 = 8,
parameter WW_v2 = 8,
parameter WW_v3 = 8,
parameter WW_v4 = 8,
parameter WW_v5 = 8,
parameter WW_v6 = 8,
parameter WW_v7 = 8,
parameter WW_v8 = 8,
parameter WW_v9 = 8,
parameter WW_v10 = 8,
parameter WW_v11 = 8,
parameter WW_v12 = 8,
parameter WW_v13 = 8
)
(
input [L - 1 : 0] LM_sel,
input [WW_v0 - 1 : 0] LM_in_v0[L - 1 : 0],
input [WW_v1 - 1 : 0] LM_in_v1[L - 1 : 0],
input [WW_v2 - 1 : 0] LM_in_v2[L - 1 : 0],
input [WW_v3 - 1 : 0] LM_in_v3[L - 1 : 0],
input [WW_v4 - 1 : 0] LM_in_v4[L - 1 : 0],
input [WW_v5 - 1 : 0] LM_in_v5[L - 1 : 0],
input [WW_v6 - 1 : 0] LM_in_v6[L - 1 : 0],
input [WW_v7 - 1 : 0] LM_in_v7[L - 1 : 0],
input [WW_v8 - 1 : 0] LM_in_v8[L - 1 : 0],
input [WW_v9 - 1 : 0] LM_in_v9[L - 1 : 0],
input [WW_v10 - 1 : 0] LM_in_v10[L - 1 : 0],
input [WW_v11 - 1 : 0] LM_in_v11[L - 1 : 0],
input [WW_v12 - 1 : 0] LM_in_v12[L - 1 : 0],
input [WW_v13 - 1 : 0] LM_in_v13[L - 1 : 0],
output [WW_v0 - 1 : 0] LM_out_v0,
output [WW_v1 - 1 : 0] LM_out_v1,
output [WW_v2 - 1 : 0] LM_out_v2,
output [WW_v3 - 1 : 0] LM_out_v3,
output [WW_v4 - 1 : 0] LM_out_v4,
output [WW_v5 - 1 : 0] LM_out_v5,
output [WW_v6 - 1 : 0] LM_out_v6,
output [WW_v7 - 1 : 0] LM_out_v7,
output [WW_v8 - 1 : 0] LM_out_v8,
output [WW_v9 - 1 : 0] LM_out_v9,
output [WW_v10 - 1 : 0] LM_out_v10,
output [WW_v11 - 1 : 0] LM_out_v11,
output [WW_v12 - 1 : 0] LM_out_v12,
output [WW_v13 - 1 : 0] LM_out_v13
);
// function: if LM_sel[i] == 1, LM_out_0 = LM_in_0[i]
wire [WW_v0 - 1 : 0] LM_in_v0_masked[L - 1 : 0];
wire [WW_v1 - 1 : 0] LM_in_v1_masked[L - 1 : 0];
wire [WW_v2 - 1 : 0] LM_in_v2_masked[L - 1 : 0];
wire [WW_v3 - 1 : 0] LM_in_v3_masked[L - 1 : 0];
wire [WW_v4 - 1 : 0] LM_in_v4_masked[L - 1 : 0];
wire [WW_v5 - 1 : 0] LM_in_v5_masked[L - 1 : 0];
wire [WW_v6 - 1 : 0] LM_in_v6_masked[L - 1 : 0];
wire [WW_v7 - 1 : 0] LM_in_v7_masked[L - 1 : 0];
wire [WW_v8 - 1 : 0] LM_in_v8_masked[L - 1 : 0];
wire [WW_v9 - 1 : 0] LM_in_v9_masked[L - 1 : 0];
wire [WW_v10 - 1 : 0] LM_in_v10_masked[L - 1 : 0];
wire [WW_v11 - 1 : 0] LM_in_v11_masked[L - 1 : 0];
wire [WW_v12 - 1 : 0] LM_in_v12_masked[L - 1 : 0];
wire [WW_v13 - 1 : 0] LM_in_v13_masked[L - 1 : 0];
wire [L - 1 : 0] LM_in_v0_masked_T[WW_v0 - 1 : 0];
wire [L - 1 : 0] LM_in_v1_masked_T[WW_v1 - 1 : 0];
wire [L - 1 : 0] LM_in_v2_masked_T[WW_v2 - 1 : 0];
wire [L - 1 : 0] LM_in_v3_masked_T[WW_v3 - 1 : 0];
wire [L - 1 : 0] LM_in_v4_masked_T[WW_v4 - 1 : 0];
wire [L - 1 : 0] LM_in_v5_masked_T[WW_v5 - 1 : 0];
wire [L - 1 : 0] LM_in_v6_masked_T[WW_v6 - 1 : 0];
wire [L - 1 : 0] LM_in_v7_masked_T[WW_v7 - 1 : 0];
wire [L - 1 : 0] LM_in_v8_masked_T[WW_v8 - 1 : 0];
wire [L - 1 : 0] LM_in_v9_masked_T[WW_v9 - 1 : 0];
wire [L - 1 : 0] LM_in_v10_masked_T[WW_v10 - 1 : 0];
wire [L - 1 : 0] LM_in_v11_masked_T[WW_v11 - 1 : 0];
wire [L - 1 : 0] LM_in_v12_masked_T[WW_v12 - 1 : 0];
wire [L - 1 : 0] LM_in_v13_masked_T[WW_v13 - 1 : 0];
genvar i;
generate
for (i = 0; i < L; i = i + 1) begin : LM_select_loop
assign LM_in_v0_masked[i] = LM_in_v0[i] & {WW_v0{LM_sel[i]}};
assign LM_in_v1_masked[i] = LM_in_v1[i] & {WW_v1{LM_sel[i]}};
assign LM_in_v2_masked[i] = LM_in_v2[i] & {WW_v2{LM_sel[i]}};
assign LM_in_v3_masked[i] = LM_in_v3[i] & {WW_v3{LM_sel[i]}};
assign LM_in_v4_masked[i] = LM_in_v4[i] & {WW_v4{LM_sel[i]}};
assign LM_in_v5_masked[i] = LM_in_v5[i] & {WW_v5{LM_sel[i]}};
assign LM_in_v6_masked[i] = LM_in_v6[i] & {WW_v6{LM_sel[i]}};
assign LM_in_v7_masked[i] = LM_in_v7[i] & {WW_v7{LM_sel[i]}};
assign LM_in_v8_masked[i] = LM_in_v8[i] & {WW_v8{LM_sel[i]}};
assign LM_in_v9_masked[i] = LM_in_v9[i] & {WW_v9{LM_sel[i]}};
assign LM_in_v10_masked[i] = LM_in_v10[i] & {WW_v10{LM_sel[i]}};
assign LM_in_v11_masked[i] = LM_in_v11[i] & {WW_v11{LM_sel[i]}};
assign LM_in_v12_masked[i] = LM_in_v12[i] & {WW_v12{LM_sel[i]}};
assign LM_in_v13_masked[i] = LM_in_v13[i] & {WW_v13{LM_sel[i]}};
end
endgenerate
genvar j;
genvar k;
generate
for(k = 0; k < L; k = k + 1) begin: LM_transpose_loop_out
for (j = 0; j < WW_v0; j = j + 1) begin: LM_transpose_loop_in_v0
assign LM_in_v0_masked_T[j][k] = LM_in_v0_masked[k][j];
end
for (j = 0; j < WW_v1; j = j + 1) begin: LM_transpose_loop_in_v1
assign LM_in_v1_masked_T[j][k] = LM_in_v1_masked[k][j];
end
for (j = 0; j < WW_v2; j = j + 1) begin: LM_transpose_loop_in_v2
assign LM_in_v2_masked_T[j][k] = LM_in_v2_masked[k][j];
end
for (j = 0; j < WW_v3; j = j + 1) begin: LM_transpose_loop_in_v3
assign LM_in_v3_masked_T[j][k] = LM_in_v3_masked[k][j];
end
for (j = 0; j < WW_v4; j = j + 1) begin: LM_transpose_loop_in_v4
assign LM_in_v4_masked_T[j][k] = LM_in_v4_masked[k][j];
end
for (j = 0; j < WW_v5; j = j + 1) begin: LM_transpose_loop_in_v5
assign LM_in_v5_masked_T[j][k] = LM_in_v5_masked[k][j];
end
for (j = 0; j < WW_v6; j = j + 1) begin: LM_transpose_loop_in_v6
assign LM_in_v6_masked_T[j][k] = LM_in_v6_masked[k][j];
end
for (j = 0; j < WW_v7; j = j + 1) begin: LM_transpose_loop_in_v7
assign LM_in_v7_masked_T[j][k] = LM_in_v7_masked[k][j];
end
for (j = 0; j < WW_v8; j = j + 1) begin: LM_transpose_loop_in_v8
assign LM_in_v8_masked_T[j][k] = LM_in_v8_masked[k][j];
end
for (j = 0; j < WW_v9; j = j + 1) begin: LM_transpose_loop_in_v9
assign LM_in_v9_masked_T[j][k] = LM_in_v9_masked[k][j];
end
for (j = 0; j < WW_v10; j = j + 1) begin: LM_transpose_loop_in_v10
assign LM_in_v10_masked_T[j][k] = LM_in_v10_masked[k][j];
end
for (j = 0; j < WW_v11; j = j + 1) begin: LM_transpose_loop_in_v11
assign LM_in_v11_masked_T[j][k] = LM_in_v11_masked[k][j];
end
for (j = 0; j < WW_v12; j = j + 1) begin: LM_transpose_loop_in_v12
assign LM_in_v12_masked_T[j][k] = LM_in_v12_masked[k][j];
end
for (j = 0; j < WW_v13; j = j + 1) begin: LM_transpose_loop_in_v13
assign LM_in_v13_masked_T[j][k] = LM_in_v13_masked[k][j];
end
end
endgenerate
genvar m;
generate
for(m = 0; m < WW_v0; m = m + 1) begin: LM_reduce_or_loop_v0
assign LM_out_v0[m] = |(LM_in_v0_masked_T[m]);
end
for(m = 0; m < WW_v1; m = m + 1) begin: LM_reduce_or_loop_v1
assign LM_out_v1[m] = |(LM_in_v1_masked_T[m]);
end
for(m = 0; m < WW_v2; m = m + 1) begin: LM_reduce_or_loop_v2
assign LM_out_v2[m] = |(LM_in_v2_masked_T[m]);
end
for(m = 0; m < WW_v3; m = m + 1) begin: LM_reduce_or_loop_v3
assign LM_out_v3[m] = |(LM_in_v3_masked_T[m]);
end
for(m = 0; m < WW_v4; m = m + 1) begin: LM_reduce_or_loop_v4
assign LM_out_v4[m] = |(LM_in_v4_masked_T[m]);
end
for(m = 0; m < WW_v5; m = m + 1) begin: LM_reduce_or_loop_v5
assign LM_out_v5[m] = |(LM_in_v5_masked_T[m]);
end
for(m = 0; m < WW_v6; m = m + 1) begin: LM_reduce_or_loop_v6
assign LM_out_v6[m] = |(LM_in_v6_masked_T[m]);
end
for(m = 0; m < WW_v7; m = m + 1) begin: LM_reduce_or_loop_v7
assign LM_out_v7[m] = |(LM_in_v7_masked_T[m]);
end
for(m = 0; m < WW_v8; m = m + 1) begin: LM_reduce_or_loop_v8
assign LM_out_v8[m] = |(LM_in_v8_masked_T[m]);
end
for(m = 0; m < WW_v9; m = m + 1) begin: LM_reduce_or_loop_v9
assign LM_out_v9[m] = |(LM_in_v9_masked_T[m]);
end
for(m = 0; m < WW_v10; m = m + 1) begin: LM_reduce_or_loop_v10
assign LM_out_v10[m] = |(LM_in_v10_masked_T[m]);
end
for(m = 0; m < WW_v11; m = m + 1) begin: LM_reduce_or_loop_v11
assign LM_out_v11[m] = |(LM_in_v11_masked_T[m]);
end
for(m = 0; m < WW_v12; m = m + 1) begin: LM_reduce_or_loop_v12
assign LM_out_v12[m] = |(LM_in_v12_masked_T[m]);
end
for(m = 0; m < WW_v13; m = m + 1) begin: LM_reduce_or_loop_v13
assign LM_out_v13[m] = |(LM_in_v13_masked_T[m]);
end
endgenerate
endmodule
module WT_group
#(
parameter WW_v0 = 8,
parameter WW_v1 = 8,
parameter WW_v2 = 8,
parameter WW_v3 = 8,
parameter WW_v4 = 8,
parameter WW_v5 = 8,
parameter WW_v6 = 8,
parameter WW_v7 = 8,
parameter WW_v8 = 8,
parameter WW_v9 = 8,
parameter WW_v10 = 8,
parameter WW_v11 = 8,
parameter WW_v12 = 8,
parameter WW_v13 = 8
)
(
input clk,
input tree_rstn,
input valid,
input [WW_v0 - 1 : 0] WT_v0_in,
input [WW_v1 - 1 : 0] WT_v1_in,
input [WW_v2 - 1 : 0] WT_v2_in,
input [WW_v3 - 1 : 0] WT_v3_in,
input [WW_v4 - 1 : 0] WT_v4_in,
input [WW_v5 - 1 : 0] WT_v5_in,
input [WW_v6 - 1 : 0] WT_v6_in,
input [WW_v7 - 1 : 0] WT_v7_in,
input [WW_v8 - 1 : 0] WT_v8_in,
input [WW_v9 - 1 : 0] WT_v9_in,
input [WW_v10 - 1 : 0] WT_v10_in,
input [WW_v11 - 1 : 0] WT_v11_in,
input [WW_v12 - 1 : 0] WT_v12_in,
input [WW_v13 - 1 : 0] WT_v13_in,
output WT_v0_out_S,
output WT_v0_out_C,
output WT_v1_out_S,
output WT_v1_out_C,
output WT_v2_out_S,
output WT_v2_out_C,
output WT_v3_out_S,
output WT_v3_out_C,
output WT_v4_out_S,
output WT_v4_out_C,
output WT_v5_out_S,
output WT_v5_out_C,
output WT_v6_out_S,
output WT_v6_out_C,
output WT_v7_out_S,
output WT_v7_out_C,
output WT_v8_out_S,
output WT_v8_out_C,
output WT_v9_out_S,
output WT_v9_out_C,
output WT_v10_out_S,
output WT_v10_out_C,
output WT_v11_out_S,
output WT_v11_out_C,
output WT_v12_out_S,
output WT_v12_out_C,
output WT_v13_out_S,
output WT_v13_out_C
);
SerialWallaceTree8Input serial_wallace_tree_v0 // module name: DG
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v0_in),
.out_S(WT_v0_out_S),
.out_Cout(WT_v0_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v1
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v1_in),
.out_S(WT_v1_out_S),
.out_Cout(WT_v1_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v2
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v2_in),
.out_S(WT_v2_out_S),
.out_Cout(WT_v2_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v3
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v3_in),
.out_S(WT_v3_out_S),
.out_Cout(WT_v3_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v4
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v4_in),
.out_S(WT_v4_out_S),
.out_Cout(WT_v4_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v5
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v5_in),
.out_S(WT_v5_out_S),
.out_Cout(WT_v5_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v6
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v6_in),
.out_S(WT_v6_out_S),
.out_Cout(WT_v6_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v7
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v7_in),
.out_S(WT_v7_out_S),
.out_Cout(WT_v7_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v8
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v8_in),
.out_S(WT_v8_out_S),
.out_Cout(WT_v8_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v9
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v9_in),
.out_S(WT_v9_out_S),
.out_Cout(WT_v9_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v10
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v10_in),
.out_S(WT_v10_out_S),
.out_Cout(WT_v10_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v11
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v11_in),
.out_S(WT_v11_out_S),
.out_Cout(WT_v11_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v12
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v12_in),
.out_S(WT_v12_out_S),
.out_Cout(WT_v12_out_C)
);
SerialWallaceTree8Input serial_wallace_tree_v13
(
.clk(clk),
.rstn(tree_rstn),
.valid(valid),
.addends(WT_v13_in),
.out_S(WT_v13_out_S),
.out_Cout(WT_v13_out_C)
);
endmodule
module Sub_wrapper // abbr. SW
#(
parameter WW_v0 = 8,
parameter WW_v1 = 8,
parameter WW_v2 = 8,
parameter WW_v3 = 8,
parameter WW_v4 = 8,
parameter WW_v5 = 8,
parameter WW_v6 = 8,
parameter WW_v7 = 8,
parameter WW_v8 = 8,
parameter WW_v9 = 8,
parameter WW_v10 = 8,
parameter WW_v11 = 8,
parameter WW_v12 = 8,
parameter WW_v13 = 8,
parameter H = 16,
parameter L = 5
)
(
input clk,
input tree_rstn,
input valid,
input [L - 1 : 0] LM_sel,
input [H - 1 : 0] SW_in,
output WT_v0_out_S,
output WT_v0_out_C,
output WT_v1_out_S,
output WT_v1_out_C,
output WT_v2_out_S,
output WT_v2_out_C,
output WT_v3_out_S,
output WT_v3_out_C,
output WT_v4_out_S,
output WT_v4_out_C,
output WT_v5_out_S,
output WT_v5_out_C,
output WT_v6_out_S,
output WT_v6_out_C,
output WT_v7_out_S,
output WT_v7_out_C,
output WT_v8_out_S,
output WT_v8_out_C,
output WT_v9_out_S,
output WT_v9_out_C,
output WT_v10_out_S,
output WT_v10_out_C,
output WT_v11_out_S,
output WT_v11_out_C,
output WT_v12_out_S,
output WT_v12_out_C,
output WT_v13_out_S,
output WT_v13_out_C
);
wire [WW_v0 - 1 : 0] HN_out_v0[L - 1 : 0];
wire [WW_v1 - 1 : 0] HN_out_v1[L - 1 : 0];
wire [WW_v2 - 1 : 0] HN_out_v2[L - 1 : 0];
wire [WW_v3 - 1 : 0] HN_out_v3[L - 1 : 0];
wire [WW_v4 - 1 : 0] HN_out_v4[L - 1 : 0];
wire [WW_v5 - 1 : 0] HN_out_v5[L - 1 : 0];
wire [WW_v6 - 1 : 0] HN_out_v6[L - 1 : 0];
wire [WW_v7 - 1 : 0] HN_out_v7[L - 1 : 0];
wire [WW_v8 - 1 : 0] HN_out_v8[L - 1 : 0];
wire [WW_v9 - 1 : 0] HN_out_v9[L - 1 : 0];
wire [WW_v10 - 1 : 0] HN_out_v10[L - 1 : 0];
wire [WW_v11 - 1 : 0] HN_out_v11[L - 1 : 0];
wire [WW_v12 - 1 : 0] HN_out_v12[L - 1 : 0];
wire [WW_v13 - 1 : 0] HN_out_v13[L - 1 : 0];
HN
#(
.WW_v0(WW_v0),
.WW_v1(WW_v1),
.WW_v2(WW_v2),
.WW_v3(WW_v3),
.WW_v4(WW_v4),
.WW_v5(WW_v5),
.WW_v6(WW_v6),
.WW_v7(WW_v7),
.WW_v8(WW_v8),
.WW_v9(WW_v9),
.WW_v10(WW_v10),
.WW_v11(WW_v11),
.WW_v12(WW_v12),
.WW_v13(WW_v13),
.H(H),
.L(L)
)
hn
(
.HN_in(SW_in),
.HN_out_v0(HN_out_v0),
.HN_out_v1(HN_out_v1),
.HN_out_v2(HN_out_v2),
.HN_out_v3(HN_out_v3),
.HN_out_v4(HN_out_v4),
.HN_out_v5(HN_out_v5),
.HN_out_v6(HN_out_v6),
.HN_out_v7(HN_out_v7),
.HN_out_v8(HN_out_v8),
.HN_out_v9(HN_out_v9),
.HN_out_v10(HN_out_v10),
.HN_out_v11(HN_out_v11),
.HN_out_v12(HN_out_v12),
.HN_out_v13(HN_out_v13)
);
wire [WW_v0 - 1 : 0] LM_out_v0;
wire [WW_v1 - 1 : 0] LM_out_v1;
wire [WW_v2 - 1 : 0] LM_out_v2;
wire [WW_v3 - 1 : 0] LM_out_v3;
wire [WW_v4 - 1 : 0] LM_out_v4;
wire [WW_v5 - 1 : 0] LM_out_v5;
wire [WW_v6 - 1 : 0] LM_out_v6;
wire [WW_v7 - 1 : 0] LM_out_v7;
wire [WW_v8 - 1 : 0] LM_out_v8;
wire [WW_v9 - 1 : 0] LM_out_v9;
wire [WW_v10 - 1 : 0] LM_out_v10;
wire [WW_v11 - 1 : 0] LM_out_v11;
wire [WW_v12 - 1 : 0] LM_out_v12;
wire [WW_v13 - 1 : 0] LM_out_v13;
Layer_mux
#(
.L(L),
.WW_v0(WW_v0),
.WW_v1(WW_v1),
.WW_v2(WW_v2),
.WW_v3(WW_v3),
.WW_v4(WW_v4),
.WW_v5(WW_v5),
.WW_v6(WW_v6),
.WW_v7(WW_v7),
.WW_v8(WW_v8),
.WW_v9(WW_v9),
.WW_v10(WW_v10),
.WW_v11(WW_v11),
.WW_v12(WW_v12),
.WW_v13(WW_v13)
)
layer_mux
(
.LM_sel(LM_sel),
.LM_in_v0(HN_out_v0),
.LM_in_v1(HN_out_v1),
.LM_in_v2(HN_out_v2),
.LM_in_v3(HN_out_v3),
.LM_in_v4(HN_out_v4),
.LM_in_v5(HN_out_v5),
.LM_in_v6(HN_out_v6),
.LM_in_v7(HN_out_v7),
.LM_in_v8(HN_out_v8),
.LM_in_v9(HN_out_v9),
.LM_in_v10(HN_out_v10),
.LM_in_v11(HN_out_v11),
.LM_in_v12(HN_out_v12),
.LM_in_v13(HN_out_v13),
.LM_out_v0(LM_out_v0),
.LM_out_v1(LM_out_v1),
.LM_out_v2(LM_out_v2),
.LM_out_v3(LM_out_v3),
.LM_out_v4(LM_out_v4),
.LM_out_v5(LM_out_v5),
.LM_out_v6(LM_out_v6),
.LM_out_v7(LM_out_v7),
.LM_out_v8(LM_out_v8),
.LM_out_v9(LM_out_v9),
.LM_out_v10(LM_out_v10),
.LM_out_v11(LM_out_v11),
.LM_out_v12(LM_out_v12),
.LM_out_v13(LM_out_v13)
);
WT_group
#(
.WW_v0(WW_v0),
.WW_v1(WW_v1),
.WW_v2(WW_v2),
.WW_v3(WW_v3),
.WW_v4(WW_v4),
.WW_v5(WW_v5),
.WW_v6(WW_v6),
.WW_v7(WW_v7),
.WW_v8(WW_v8),
.WW_v9(WW_v9),
.WW_v10(WW_v10),
.WW_v11(WW_v11),
.WW_v12(WW_v12),
.WW_v13(WW_v13)
)
wt_group
(
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.WT_v0_in(LM_out_v0),
.WT_v1_in(LM_out_v1),
.WT_v2_in(LM_out_v2),
.WT_v3_in(LM_out_v3),
.WT_v4_in(LM_out_v4),
.WT_v5_in(LM_out_v5),
.WT_v6_in(LM_out_v6),
.WT_v7_in(LM_out_v7),
.WT_v8_in(LM_out_v8),
.WT_v9_in(LM_out_v9),
.WT_v10_in(LM_out_v10),
.WT_v11_in(LM_out_v11),
.WT_v12_in(LM_out_v12),
.WT_v13_in(LM_out_v13),
.WT_v0_out_S(WT_v0_out_S),
.WT_v0_out_C(WT_v0_out_C),
.WT_v1_out_S(WT_v1_out_S),
.WT_v1_out_C(WT_v1_out_C),
.WT_v2_out_S(WT_v2_out_S),
.WT_v2_out_C(WT_v2_out_C),
.WT_v3_out_S(WT_v3_out_S),
.WT_v3_out_C(WT_v3_out_C),
.WT_v4_out_S(WT_v4_out_S),
.WT_v4_out_C(WT_v4_out_C),
.WT_v5_out_S(WT_v5_out_S),
.WT_v5_out_C(WT_v5_out_C),
.WT_v6_out_S(WT_v6_out_S),
.WT_v6_out_C(WT_v6_out_C),
.WT_v7_out_S(WT_v7_out_S),
.WT_v7_out_C(WT_v7_out_C),
.WT_v8_out_S(WT_v8_out_S),
.WT_v8_out_C(WT_v8_out_C),
.WT_v9_out_S(WT_v9_out_S),
.WT_v9_out_C(WT_v9_out_C),
.WT_v10_out_S(WT_v10_out_S),
.WT_v10_out_C(WT_v10_out_C),
.WT_v11_out_S(WT_v11_out_S),
.WT_v11_out_C(WT_v11_out_C),
.WT_v12_out_S(WT_v12_out_S),
.WT_v12_out_C(WT_v12_out_C),
.WT_v13_out_S(WT_v13_out_S),
.WT_v13_out_C(WT_v13_out_C)
);
endmodule
\ No newline at end of file
module FullAdder(
input A, // First input bit
input B, // Second input bit
input Cin, // Carry input bit
output S, // Sum output bit
output Cout // Carry output bit
);
assign S = A ^ B ^ Cin;
assign Cout = (A & B) | (B & Cin) | (A & Cin);
endmodule
module WallaceTree8Input(
input [8 - 1 : 0] addends,
output [2 - 1 : 0] L0_Cout,
input [2 - 1 : 0] L1_Cin,
output [2 - 1 : 0] L1_Cout,
input [2 - 1 : 0] L2_Cin,
output [1 - 1 : 0] L2_Cout,
input [1 - 1 : 0] L3_Cin,
output final_Cout,
output final_S
);
wire [8 - 1 : 0] L0_all_inputs;
assign L0_all_inputs = addends;
wire [2 - 1 : 0] L0_remainder;
assign L0_remainder = L0_all_inputs[8 - 1 : 8 - 2];
wire [2 - 1 : 0] L0_S;
FullAdder L0_adders [2 - 1 : 0](
.A(L0_all_inputs[2 * 3 - 1 : 2 * 2]),
.B(L0_all_inputs[2 * 2 - 1 : 2]),
.Cin(L0_all_inputs[2 - 1 : 0]),
.Cout(L0_Cout),
.S(L0_S)
);
wire [6 - 1 : 0] L1_all_inputs;
assign L1_all_inputs = {L0_S, L1_Cin, L0_remainder};
wire [2 - 1 : 0] L1_S;
FullAdder L1_adders [2 - 1 : 0](
.A(L1_all_inputs[2 * 3 - 1 : 2 * 2]),
.B(L1_all_inputs[2 * 2 - 1 : 2]),
.Cin(L1_all_inputs[2 - 1 : 0]),
.Cout(L1_Cout),
.S(L1_S)
);
wire [4 - 1 : 0] L2_all_inputs;
assign L2_all_inputs = {L1_S, L2_Cin};
wire [1 - 1 : 0] L2_remainder;
assign L2_remainder = L2_all_inputs[4 - 1 : 4 - 1];
wire [1 - 1 : 0] L2_S;
FullAdder L2_adders [1 - 1 : 0](
.A(L2_all_inputs[1 * 3 - 1 : 1 * 2]),
.B(L2_all_inputs[1 * 2 - 1 : 1]),
.Cin(L2_all_inputs[1 - 1 : 0]),
.Cout(L2_Cout),
.S(L2_S)
);
wire [3 - 1 : 0] L3_all_inputs;
assign L3_all_inputs = {L2_S, L3_Cin, L2_remainder};
FullAdder L3_adders [1 - 1 : 0](
.A(L3_all_inputs[1 * 3 - 1 : 1 * 2]),
.B(L3_all_inputs[1 * 2 - 1 : 1]),
.Cin(L3_all_inputs[1 - 1 : 0]),
.Cout(final_Cout),
.S(final_S)
);
endmodule
module SerialWallaceTree8Input(
input clk,
input rstn,
input valid,
input [8 - 1 : 0] addends,
output out_S,
output out_Cout
);
wire [2 - 1 : 0] L0_Cout;
wire [2 - 1 : 0] L1_Cin;
reg [2 - 1 : 0] L0_Cout_L1_Cin_reg;
assign L1_Cin = L0_Cout_L1_Cin_reg;
wire [2 - 1 : 0] L1_Cout;
wire [2 - 1 : 0] L2_Cin;
reg [2 - 1 : 0] L1_Cout_L2_Cin_reg;
assign L2_Cin = L1_Cout_L2_Cin_reg;
wire [1 - 1 : 0] L2_Cout;
wire [1 - 1 : 0] L3_Cin;
reg [1 - 1 : 0] L2_Cout_L3_Cin_reg;
assign L3_Cin = L2_Cout_L3_Cin_reg;
wire final_S, final_Cout;
assign out_S = final_S & valid;
assign out_Cout = final_Cout & valid;
WallaceTree8Input u_WallaceTree8Input(
.addends(addends),
.L0_Cout(L0_Cout),
.L1_Cin(L1_Cin),
.L1_Cout(L1_Cout),
.L2_Cin(L2_Cin),
.L2_Cout(L2_Cout),
.L3_Cin(L3_Cin),
.final_S(final_S),
.final_Cout(final_Cout)
);
always @ (posedge clk) begin
if (!rstn) begin
L0_Cout_L1_Cin_reg <= 2'b0;
L1_Cout_L2_Cin_reg <= 2'b0;
L2_Cout_L3_Cin_reg <= 1'b0;
end
else if (valid) begin
L0_Cout_L1_Cin_reg <= L0_Cout;
L1_Cout_L2_Cin_reg <= L1_Cout;
L2_Cout_L3_Cin_reg <= L2_Cout;
end
end
endmodule
module Wrappers_tp_k #(
parameter H = 16, //这些数值无所谓
parameter L = 2,
parameter VN = 2
) (
input clk,
input tree_rstn,
input valid,
input CST_LOW,
input [L - 1 : 0] LM_sel,
input [H - 1 : 0] SW_in,
output WT_v0_out_S[VN - 1 : 0],
output WT_v0_out_C[VN - 1 : 0],
output WT_v1_out_S[VN - 1 : 0],
output WT_v1_out_C[VN - 1 : 0],
output WT_v2_out_S[VN - 1 : 0],
output WT_v2_out_C[VN - 1 : 0],
output WT_v3_out_S[VN - 1 : 0],
output WT_v3_out_C[VN - 1 : 0],
output WT_v4_out_S[VN - 1 : 0],
output WT_v4_out_C[VN - 1 : 0],
output WT_v5_out_S[VN - 1 : 0],
output WT_v5_out_C[VN - 1 : 0],
output WT_v6_out_S[VN - 1 : 0],
output WT_v6_out_C[VN - 1 : 0],
output WT_v7_out_S[VN - 1 : 0],
output WT_v7_out_C[VN - 1 : 0],
output WT_v8_out_S[VN - 1 : 0],
output WT_v8_out_C[VN - 1 : 0],
output WT_v9_out_S[VN - 1 : 0],
output WT_v9_out_C[VN - 1 : 0],
output WT_v10_out_S[VN - 1 : 0],
output WT_v10_out_C[VN - 1 : 0],
output WT_v11_out_S[VN - 1 : 0],
output WT_v11_out_C[VN - 1 : 0],
output WT_v12_out_S[VN - 1 : 0],
output WT_v12_out_C[VN - 1 : 0],
output WT_v13_out_S[VN - 1 : 0],
output WT_v13_out_C[VN - 1 : 0]
);
Mid_wrapper_tp_k_gp_0 mid_wrapper_0 //名称手动迭代下,还有下面的索引
(
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(SW_in),
.WT_0_out_S(WT_v0_out_S[15:0]),
.WT_0_out_C(WT_v0_out_C[0]),
.WT_1_out_S(WT_v1_out_S[0]),
.WT_1_out_C(WT_v1_out_C[0]),
.WT_2_out_S(WT_v2_out_S[0]),
.WT_2_out_C(WT_v2_out_C[0]),
.WT_3_out_S(WT_v3_out_S[0]),
.WT_3_out_C(WT_v3_out_C[0]),
.WT_4_out_S(WT_v4_out_S[0]),
.WT_4_out_C(WT_v4_out_C[0]),
.WT_5_out_S(WT_v5_out_S[0]),
.WT_5_out_C(WT_v5_out_C[0]),
.WT_6_out_S(WT_v6_out_S[0]),
.WT_6_out_C(WT_v6_out_C[0]),
.WT_7_out_S(WT_v7_out_S[0]),
.WT_7_out_C(WT_v7_out_C[0]),
.WT_8_out_S(WT_v8_out_S[0]),
.WT_8_out_C(WT_v8_out_C[0]),
.WT_9_out_S(WT_v9_out_S[0]),
.WT_9_out_C(WT_v9_out_C[0]),
.WT_10_out_S(WT_v10_out_S[0]),
.WT_10_out_C(WT_v10_out_C[0]),
.WT_11_out_S(WT_v11_out_S[0]),
.WT_11_out_C(WT_v11_out_C[0]),
.WT_12_out_S(WT_v12_out_S[0]),
.WT_12_out_C(WT_v12_out_C[0]),
.WT_13_out_S(WT_v13_out_S[0]),
.WT_13_out_C(WT_v13_out_C[0])
);
Mid_wrapper_tp_k_gp_1 mid_wrapper_1 (
.clk(clk),
.tree_rstn(tree_rstn),
.valid(valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(SW_in),
.WT_0_out_S(WT_v0_out_S[1]),
.WT_0_out_C(WT_v0_out_C[1]),
.WT_1_out_S(WT_v1_out_S[1]),
.WT_1_out_C(WT_v1_out_C[1]),
.WT_2_out_S(WT_v2_out_S[1]),
.WT_2_out_C(WT_v2_out_C[1]),
.WT_3_out_S(WT_v3_out_S[1]),
.WT_3_out_C(WT_v3_out_C[1]),
.WT_4_out_S(WT_v4_out_S[1]),
.WT_4_out_C(WT_v4_out_C[1]),
.WT_5_out_S(WT_v5_out_S[1]),
.WT_5_out_C(WT_v5_out_C[1]),
.WT_6_out_S(WT_v6_out_S[1]),
.WT_6_out_C(WT_v6_out_C[1]),
.WT_7_out_S(WT_v7_out_S[1]),
.WT_7_out_C(WT_v7_out_C[1]),
.WT_8_out_S(WT_v8_out_S[1]),
.WT_8_out_C(WT_v8_out_C[1]),
.WT_9_out_S(WT_v9_out_S[1]),
.WT_9_out_C(WT_v9_out_C[1]),
.WT_10_out_S(WT_v10_out_S[1]),
.WT_10_out_C(WT_v10_out_C[1]),
.WT_11_out_S(WT_v11_out_S[1]),
.WT_11_out_C(WT_v11_out_C[1]),
.WT_12_out_S(WT_v12_out_S[1]),
.WT_12_out_C(WT_v12_out_C[1]),
.WT_13_out_S(WT_v13_out_S[1]),
.WT_13_out_C(WT_v13_out_C[1])
);
endmodule
def read_mem_file(filepath):
with open(filepath, 'r') as file:
lines = file.readlines()
# 去除每行数据中的下划线并转换为十进制数
mem_vector = [int(line.strip().replace('_', ''), 2) for line in lines]
return mem_vector
def read_weight_file(filepath):
with open(filepath, 'r') as file:
lines = file.readlines()
# 读取每个权重并转换为整数
weight_vector = [int(line.strip()) for line in lines]
return weight_vector
def read_result_file(filepath):
with open(filepath, 'r') as file:
# 读取第一行并将其作为二进制数进行解释
result_binary_str = file.readline().strip()
# 将二进制数作为有符号数转换为十进制数
result_decimal = int(result_binary_str, 2)
# 如果二进制数是负数,则需要进行二补码转换
if result_decimal >= 2 ** (len(result_binary_str) - 1):
result_decimal -= 2 ** len(result_binary_str)
return result_decimal
def vector_multiplication(mem_vector, weight_vector):
# 对位乘法
product_vector = [m * w for m, w in zip(mem_vector, weight_vector)]
return product_vector
def main():
mem_vector = read_mem_file('F:/another-D/vivao/vivado_project/project_10_wallace_FSM_MUX/mem.txt')
weight_vector = read_weight_file('F:/another-D/vivao/vivado_project/project_10_wallace_FSM_MUX/weight.txt')
# 对位乘法
product_vector = vector_multiplication(mem_vector, weight_vector)
print("对位乘法结果:")
for i, product in enumerate(product_vector):
print(f"mem[{i}] * weight[{i}] = {mem_vector[i]} * {weight_vector[i]} = {product}")
# 总和
total_sum = sum(product_vector)
print(f"\n总和: {total_sum}; 二进制表示: {bin(total_sum)}")
# 读取result.txt中的二进制数并转换为十进制数
result_decimal = read_result_file('F:/another-D/vivao/vivado_project/project_10_wallace_FSM_MUX/result.txt')
print(f"\n从result.txt读取的二进制数对应的十进制值为: {result_decimal}\n")
# 对比result.txt中的值与部分和总和
if result_decimal == total_sum:
print("相同✓\n")
else:
print("不同×\n")
if __name__ == "__main__":
main()
module optimized_mux_tp_o_vc_1_value_3_cnt_0 (
input [1536 - 1:0] in, // 64-bit input signals
input [5:0] sel, // 6-bit binary selector signal
output out // Selected output
);
//参考原有逻辑,然后没得连的就连0
endmodule
module mux_52to1_binary (
input [ 13 - 1: 0] in, // 64-bit input signals
input [ 5: 0] sel, // 6-bit binary selector signal
output out // Selected output
);
reg [ 52 - 1: 0] par_out; // parallel
always @(*) begin
if (sel == 6'b000000) begin
par_out[0] = in[0];
end else begin
par_out[0] = 'b0;
end
if (sel == 6'b000001) begin
par_out[1] = in[0];
end else begin
par_out[1] = 'b0;
end
if (sel == 6'b000010) begin
par_out[2] = in[0];
end else begin
par_out[2] = 'b0;
end
if (sel == 6'b000011) begin
par_out[3] = in[0];
end else begin
par_out[3] = 'b0;
end
if (sel == 6'b000100) begin
par_out[4] = in[1];
end else begin
par_out[4] = 'b0;
end
if (sel == 6'b000101) begin
par_out[5] = in[1];
end else begin
par_out[5] = 'b0;
end
if (sel == 6'b000110) begin
par_out[6] = in[1];
end else begin
par_out[6] = 'b0;
end
if (sel == 6'b000111) begin
par_out[7] = in[1];
end else begin
par_out[7] = 'b0;
end
if (sel == 6'b001000) begin
par_out[8] = in[2];
end else begin
par_out[8] = 'b0;
end
if (sel == 6'b001001) begin
par_out[9] = in[2];
end else begin
par_out[9] = 'b0;
end
if (sel == 6'b001010) begin
par_out[10] = in[2];
end else begin
par_out[10] ='b0;
end
if (sel == 6'b001011) begin
par_out[11] = in[2];
end else begin
par_out[11] = 'b0;
end
if (sel == 6'b001100) begin
par_out[12] = in[3];
end else begin
par_out[12] = 'b0;
end
if (sel == 6'b001101) begin
par_out[13] = in[3];
end else begin
par_out[13] = 'b0;
end
if (sel == 6'b001110) begin
par_out[14] = in[3];
end else begin
par_out[14] = 'b0;
end
if (sel == 6'b001111) begin
par_out[15] = in[3];
end else begin
par_out[15] = 'b0;
end
if (sel == 6'b010000) begin
par_out[16] = in[4];
end else begin
par_out[16] = 'b0;
end
if (sel == 6'b010001) begin
par_out[17] = in[4];
end else begin
par_out[17] = 'b0;
end
if (sel == 6'b010010) begin
par_out[18] = in[4];
end else begin
par_out[18] = 'b0;
end
if (sel == 6'b010011) begin
par_out[19] = in[4];
end else begin
par_out[19] = 'b0;
end
if (sel == 6'b010100) begin
par_out[20] = in[5];
end else begin
par_out[20] = 'b0;
end
if (sel == 6'b010101) begin
par_out[21] = in[5];
end else begin
par_out[21] = 'b0;
end
if (sel == 6'b010110) begin
par_out[22] = in[5];
end else begin
par_out[22] = 'b0;
end
if (sel == 6'b010111) begin
par_out[23] = in[5];
end else begin
par_out[23] = 'b0;
end
if (sel == 6'b011000) begin
par_out[24] = in[6];
end else begin
par_out[24] = 'b0;
end
if (sel == 6'b011001) begin
par_out[25] = in[6];
end else begin
par_out[25] = 'b0;
end
if (sel == 6'b011010) begin
par_out[26] = in[6];
end else begin
par_out[26] = 'b0;
end
if (sel == 6'b011011) begin
par_out[27] = in[6];
end else begin
par_out[27] = 'b0;
end
if (sel == 6'b011100) begin
par_out[28] = in[7];
end else begin
par_out[28] = 'b0;
end
if (sel == 6'b011101) begin
par_out[29] = in[7];
end else begin
par_out[29] = 'b0;
end
if (sel == 6'b011110) begin
par_out[30] = in[7];
end else begin
par_out[30] = 'b0;
end
if (sel == 6'b011111) begin
par_out[31] = in[7];
end else begin
par_out[31] = 'b0;
end
if (sel == 6'b100000) begin
par_out[32] = in[8];
end else begin
par_out[32] = 'b0;
end
if (sel == 6'b100001) begin
par_out[33] = in[8];
end else begin
par_out[33] = 'b0;
end
if (sel == 6'b100010) begin
par_out[34] = in[8];
end else begin
par_out[34] = 'b0;
end
if (sel == 6'b100011) begin
par_out[35] = in[8];
end else begin
par_out[35] = 'b0;
end
if (sel == 6'b100100) begin
par_out[36] = in[9];
end else begin
par_out[36] = 'b0;
end
if (sel == 6'b100101) begin
par_out[37] = in[9];
end else begin
par_out[37] = 'b0;
end
if (sel == 6'b100110) begin
par_out[38] = in[9];
end else begin
par_out[38] = 'b0;
end
if (sel == 6'b100111) begin
par_out[39] = in[9];
end else begin
par_out[39] = 'b0;
end
if (sel == 6'b101000) begin
par_out[40] = in[10];
end else begin
par_out[40] = 'b0;
end
if (sel == 6'b101001) begin
par_out[41] = in[10];
end else begin
par_out[41] = 'b0;
end
if (sel == 6'b101010) begin
par_out[42] = in[10];
end else begin
par_out[42] = 'b0;
end
if (sel == 6'b101011) begin
par_out[43] = in[10];
end else begin
par_out[43] = 'b0;
end
if (sel == 6'b101100) begin
par_out[44] = in[11];
end else begin
par_out[44] = 'b0;
end
if (sel == 6'b101101) begin
par_out[45] = in[11];
end else begin
par_out[45] = 'b0;
end
if (sel == 6'b101110) begin
par_out[46] = in[11];
end else begin
par_out[46] = 'b0;
end
if (sel == 6'b101111) begin
par_out[47] = in[11];
end else begin
par_out[47] = 'b0;
end
if (sel == 6'b110000) begin
par_out[48] = in[12];
end else begin
par_out[48] = 'b0;
end
if (sel == 6'b110001) begin
par_out[49] = in[12];
end else begin
par_out[49] = 'b0;
end
if (sel == 6'b110010) begin
par_out[50] = in[12];
end else begin
par_out[50] = 'b0;
end
if (sel == 6'b110011) begin
par_out[51] = in[12];
end else begin
par_out[51] = 'b0;
end
end
assign out = par_out[0] |par_out[1] |par_out[2] |par_out[3] |par_out[4] |par_out[5] |par_out[6] |par_out[7] |par_out[8] |par_out[9] |par_out[10] |par_out[11] |par_out[12] |par_out[13] |par_out[14] |par_out[15] |par_out[16] |par_out[17] |par_out[18] |par_out[19] |par_out[20] |par_out[21] |par_out[22] |par_out[23] |par_out[24] |par_out[25] |par_out[26] |par_out[27] |par_out[28] |par_out[29] |par_out[30] |par_out[31] |par_out[32] |par_out[33] |par_out[34] |par_out[35] |par_out[36] |par_out[37] |par_out[38] |par_out[39] |par_out[40] |par_out[41] |par_out[42] |par_out[43] |par_out[44] |par_out[45] |par_out[46] |par_out[47] |par_out[48] |par_out[49] |par_out[50] |par_out[51];
endmodule
module optimized_mux_wrappaer_tp_o_vc_1_value_3 (
input [1536 - 1:0] in, // 64-bit input signals
input [5:0] sel, // 6-bit binary selector signal
output [max_color - 1:0] out // Selected output
);
//参考原有逻辑,然后没得连的就连0
optimized_mux_tp_o_vc_1_value_3_cnt_0 mux_0 (
.in (in),
.sel(sel),
.out(out[0])
);
optimized_mux_tp_o_vc_1_value_3_cnt_1 mux_1 (
.in (in),
.sel(sel),
.out(out[1])
);
optimized_mux_tp_o_vc_1_value_3_cnt_2 mux_2 (
.in (in),
.sel(sel),
.out(out[2])
);
endmodule
// 20240917 测试top fsm
`timescale 1ns / 1ps
module Testbench;
reg [7 : 0] Top_in [15 : 0];
reg [1 : 0] LM_sel;
reg valid, clk, fsm_rstn;
wire [16 : 0] WT_result_acc[2 : 0];
wire result_valid;
integer fd, i;
top_FSM #(
.L(2),
.WW_0(8),
.WW_1(8),
.H(16),
.VN(3),
.AP(8),
.SCW(11),
.SCWB(4),
.TTW(16)
) fsm (
.clk(clk),
.valid(valid),
.fsm_rstn(fsm_rstn),
.LM_sel(LM_sel),
.Top_in(Top_in),
.WT_result_acc(WT_result_acc),
.result_valid(result_valid)
);
always #(5) clk = ~clk;
initial begin
fd = $fopen("C:/Users/night/Document/Code/H-LLM/text/result.txt", "w");
if (fd == 0) begin
$display("Error: Unable to open file.");
$finish;
end
$readmemb("C:/Users/night/Document/Code/H-LLM/text/mem.txt", Top_in);
LM_sel = {2'b10};
clk = 0;
fsm_rstn = 0;
#10 fsm_rstn = 1;
valid = 1;
#(10 * 25) $finish;
end
always @(posedge clk) begin
if (result_valid) begin
for (i = 0; i < 3; i = i + 1) begin
$fwrite(fd, "%b\n", WT_result_acc[i]);
end
$fclose(fd);
end
end
endmodule
module wt_group_tp_o_vc_1 (
input [1536 - 1:0] in, // 64-bit input signals
input [5:0] sel, // 6-bit binary selector signal
output [max_color - 1:0] out // Selected output
);
endmodule
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment