Commit 08fb0edf by zhengzifu

Add FSM_tp_k_gp_0_0127 module and related templates; refactor output directory…

Add FSM_tp_k_gp_0_0127 module and related templates; refactor output directory handling in generation scripts
parent 26e7f36d
......@@ -536,10 +536,12 @@ def process_task(i, name, weights_file_name, matrix, H, L, VN, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -558,5 +560,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
......@@ -35,23 +35,24 @@ def process_weight(args):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
print(f"Processing {weights_file_name}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
matrixs = np.transpose(matrixs, (1, 0, 2))
VN, L, H = matrixs.shape
# print(VN, L, H)
path_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Generating color graph")
args = [
(i, weight, L, H, config.value_range, weights_file_name, path_dir)
(i, weight, L, H, config.value_range, weights_file_name, file_dir)
for i, weight in enumerate(matrixs)
]
with Pool(config.num_workers) as pool:
list(tqdm(pool.imap(process_weight, args), total=VN))
print("Generating color graph at", path_dir)
print("Generating color graph at", file_dir)
......@@ -107,10 +107,12 @@ def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -128,5 +130,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
......@@ -147,10 +147,12 @@ def process_task(i, name, weights_file_name, ww_list, H, L, VN, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -185,5 +187,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
......@@ -120,10 +120,12 @@ def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file_name}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -141,7 +143,6 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
......
......@@ -106,10 +106,12 @@ def process_task(i, name, weights_file_name, H, L, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -125,5 +127,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
# %%
import shutil
import sys
import numpy as np
from pyrilog import (
......@@ -125,6 +126,10 @@ def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
def run(name, config: CFG):
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file_name}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -142,5 +147,4 @@ def run(name, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
# %%
import shutil
import sys
import numpy as np
from pyrilog import (
......@@ -60,8 +61,8 @@ def generate_module(
"LM_sel": "LM_sel",
}
for j in range(len(value_range)):
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i*GN+GN-1}:{i*GN}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i*GN+GN-1}:{i*GN}]"
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i * GN + GN - 1}:{i * GN}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i * GN + GN - 1}:{i * GN}]"
add_instance(
f"Mid_wrapper_tp_{weights_file_name}_gp_{i}",
f"Mid_wrapper_{i}",
......@@ -99,6 +100,10 @@ def process_task(i, name, weights_file_name, H, L, VN, config: CFG):
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -123,5 +128,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import shutil
import sys
import numpy as np
......@@ -106,9 +107,12 @@ def process_task(i, name, weights_file_name, H, L, config: CFG):
def run(name: str, config: CFG):
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.mapped_weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file_name}")
with open(weights_file, "rb") as f:
print(f"Loading {weights_file_name}")
......@@ -125,7 +129,6 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print(f"Generated {name} at {file_dir}")
......
......@@ -539,10 +539,9 @@ def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
output_dir = os.path.join(config.output_dir, name, weights_file_name)
shutil.rmtree(output_dir, ignore_errors=True)
if os.path.exists(output_dir):
shutil.rmtree(output_dir, ignore_errors=True)
os.makedirs(output_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......
......@@ -86,10 +86,12 @@ def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -107,5 +109,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
......@@ -106,10 +106,12 @@ def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -127,5 +129,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
......@@ -139,10 +139,12 @@ def process_task(i, name, weights_file_name, ww_list, H, L, VN, config: CFG):
def run(name: str, config: CFG):
shutil.rmtree(config.output_dir, ignore_errors=True)
os.makedirs(config.output_dir, exist_ok=True)
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -179,5 +181,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import shutil
import numpy as np
from pyrilog import (
ModuleBlock,
......@@ -133,6 +134,10 @@ def process_task(i, name, weights_file_name, matrix, H, L, config):
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -150,5 +155,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name)
print("Files generated in", file_dir)
import argparse
import shutil
import sys
import math
import os
......@@ -50,7 +51,7 @@ def gen_wallacetree(num_addends, full_adder_list, remainder_list, total_input_li
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1: # the final bit we will manually manage
cout_cin_code += f" output [{full_adder_count} - 1 : 0] L{i}_Cout,\n"
cout_cin_code += f" input [{full_adder_count} - 1 : 0] L{i+1}_Cin,\n"
cout_cin_code += f" input [{full_adder_count} - 1 : 0] L{i + 1}_Cin,\n"
module_head_code = f"""module WallaceTree{num_addends}Input(
input [{num_addends} - 1 : 0] addends,
......@@ -70,9 +71,9 @@ def gen_wallacetree(num_addends, full_adder_list, remainder_list, total_input_li
else:
last_remainder_count = remainder_list[i - 1]
if last_remainder_count == 0:
concat_code = f"{{L{i-1}_S, L{i}_Cin}}"
concat_code = f"{{L{i - 1}_S, L{i}_Cin}}"
else:
concat_code = f"{{L{i-1}_S, L{i}_Cin, L{i-1}_remainder}}"
concat_code = f"{{L{i - 1}_S, L{i}_Cin, L{i - 1}_remainder}}"
code += f" assign L{i}_all_inputs = {concat_code};\n"
......@@ -121,15 +122,17 @@ def gen_serialwallacetree(num_addends, full_adder_list):
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1:
code += f" wire [{full_adder_count} - 1 : 0] L{i}_Cout;\n"
code += f" wire [{full_adder_count} - 1 : 0] L{i+1}_Cin;\n"
code += f" reg [{full_adder_count} - 1 : 0] L{i}_Cout_L{i+1}_Cin_reg;\n"
code += f" assign L{i+1}_Cin = L{i}_Cout_L{i+1}_Cin_reg;\n\n"
code += f" wire [{full_adder_count} - 1 : 0] L{i + 1}_Cin;\n"
code += (
f" reg [{full_adder_count} - 1 : 0] L{i}_Cout_L{i + 1}_Cin_reg;\n"
)
code += f" assign L{i + 1}_Cin = L{i}_Cout_L{i + 1}_Cin_reg;\n\n"
cin_cout_assign_code = ""
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1:
cin_cout_assign_code += f" .L{i}_Cout(L{i}_Cout),\n"
cin_cout_assign_code += f" .L{i+1}_Cin(L{i+1}_Cin),\n"
cin_cout_assign_code += f" .L{i + 1}_Cin(L{i + 1}_Cin),\n"
code += " wire final_S, final_Cout;\n"
code += " assign out_S = final_S & valid;\n"
......@@ -149,9 +152,9 @@ def gen_serialwallacetree(num_addends, full_adder_list):
for i, full_adder_count in enumerate(full_adder_list):
if i != len(full_adder_list) - 1:
reset_code += (
f" L{i}_Cout_L{i+1}_Cin_reg <= {full_adder_count}'b0;\n"
f" L{i}_Cout_L{i + 1}_Cin_reg <= {full_adder_count}'b0;\n"
)
reg_assign_code += f" L{i}_Cout_L{i+1}_Cin_reg <= L{i}_Cout&{{{full_adder_count}{{valid}}}};\n"
reg_assign_code += f" L{i}_Cout_L{i + 1}_Cin_reg <= L{i}_Cout&{{{full_adder_count}{{valid}}}};\n"
code += f"""\
always @ (posedge clk or negedge rstn) begin
......@@ -196,11 +199,11 @@ def run(name: str, config: CFG):
else:
output_code = wallace_tree_code + serial_wallace_tree_code
# Output the code to a file
os.makedirs(os.path.join(config.output_dir, name), exist_ok=True)
output_filename = os.path.join(
config.output_dir, name, f"SerialWallaceTree{i}Input.v"
)
output_dir = os.path.join(config.output_dir, name)
if os.path.exists(output_dir):
shutil.rmtree(output_dir, ignore_errors=True)
os.makedirs(output_dir, exist_ok=True)
output_filename = os.path.join(output_dir, f"SerialWallaceTree{i}Input.v")
with open(output_filename, "w") as file:
file.write(output_code)
file_dir = os.path.join(config.output_dir, name)
print("Files generated in", file_dir)
print("Files generated in", output_dir)
# %%
import shutil
import sys
import numpy as np
from pyrilog import (
......@@ -59,8 +60,8 @@ def generate_module(
"LM_sel": "LM_sel",
}
for j in range(len(value_range)):
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i*GN+GN-1}:{i*GN}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i*GN+GN-1}:{i*GN}]"
sw_ports[f"WT_{j}_out_S"] = f"WT_{j}_out_S[{i * GN + GN - 1}:{i * GN}]"
sw_ports[f"WT_{j}_out_C"] = f"WT_{j}_out_C[{i * GN + GN - 1}:{i * GN}]"
add_instance(
f"Mid_wrapper_tp_{weights_file_name}_gp_{i}",
f"Mid_wrapper_{i}",
......@@ -98,6 +99,10 @@ def process_task(i, name, weights_file_name, H, L, VN, config: CFG):
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -122,5 +127,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
# %%
import sys
import shutil
import numpy as np
from pyrilog import (
VerilogGenerator,
......@@ -90,6 +91,10 @@ def process_task(i, name, weights_file_name, matrix, H, L, config: CFG):
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -107,5 +112,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
import sys
import shutil
import numpy as np
from pyrilog import (
VerilogGenerator,
......@@ -37,6 +38,10 @@ def process_task(i, name, weights_file_name, matrix, H, L, config: CFG = None):
def run(name: str, config: CFG):
weights_file = os.path.join(config.weights_dir, config.run_weights)
weights_file_name = os.path.splitext(os.path.basename(weights_file))[0]
file_dir = os.path.join(config.output_dir, name, weights_file_name)
if os.path.exists(file_dir):
shutil.rmtree(file_dir, ignore_errors=True)
os.makedirs(file_dir, exist_ok=True)
print(f"Processing {weights_file}")
with open(weights_file, "rb") as f:
matrixs = pickle.load(f)
......@@ -54,5 +59,4 @@ def run(name: str, config: CFG):
result = future.result()
except Exception as e:
print(f"Generating {result} failed with an error: {e}")
file_dir = os.path.join(config.output_dir, name, weights_file_name)
print("Files generated in", file_dir)
module QW_FSM_tp_k_gp_0_0127_VN32_dontuse #(
parameter H = 896, // hidden layer dim, 1536
parameter L = 24, // layer num, 52
parameter VN = 32, // vector num, 1536 or 3840(in FFN)
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 19, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 28 // MAC output total width, SCW + WP + 4
) (
input clk,
input core_valid,
input rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
output reg [TTW - 1 : 0] WT_result_acc[VN - 1 : 0],
output reg result_valid
);
reg [AP - 1 : 0] TM_sel;
wire [ H - 1 : 0] TM_out;
Top_mux #(
.H (H),
.AP(AP)
) top_mux (
.TM_sel(TM_sel),
.TM_in (Top_in),
.TM_out(TM_out)
);
wire WT_v0_out_S[VN - 1 : 0];
wire WT_v0_out_C[VN - 1 : 0];
wire WT_v1_out_S[VN - 1 : 0];
wire WT_v1_out_C[VN - 1 : 0];
wire WT_v2_out_S[VN - 1 : 0];
wire WT_v2_out_C[VN - 1 : 0];
wire WT_v3_out_S[VN - 1 : 0];
wire WT_v3_out_C[VN - 1 : 0];
wire WT_v4_out_S[VN - 1 : 0];
wire WT_v4_out_C[VN - 1 : 0];
wire WT_v5_out_S[VN - 1 : 0];
wire WT_v5_out_C[VN - 1 : 0];
wire WT_v6_out_S[VN - 1 : 0];
wire WT_v6_out_C[VN - 1 : 0];
wire WT_v7_out_S[VN - 1 : 0];
wire WT_v7_out_C[VN - 1 : 0];
wire WT_v8_out_S[VN - 1 : 0];
wire WT_v8_out_C[VN - 1 : 0];
wire WT_v9_out_S[VN - 1 : 0];
wire WT_v9_out_C[VN - 1 : 0];
wire WT_v10_out_S[VN - 1 : 0];
wire WT_v10_out_C[VN - 1 : 0];
wire WT_v11_out_S[VN - 1 : 0];
wire WT_v11_out_C[VN - 1 : 0];
wire WT_v12_out_S[VN - 1 : 0];
wire WT_v12_out_C[VN - 1 : 0];
wire WT_v13_out_S[VN - 1 : 0];
wire WT_v13_out_C[VN - 1 : 0];
reg tree_valid; // carry reg clear if not valid
reg [SCW - 1 : 0] final_S_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v0[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v1[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v2[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v3[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v4[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v5[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v6[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v7[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v8[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v9[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v10[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v11[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v12[VN - 1 : 0];
reg [SCW - 1 : 0] final_S_v13[VN - 1 : 0];
reg [SCW - 1 : 0] final_C_v13[VN - 1 : 0];
wire [TTW - 1 : 0] MAC_out[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v0[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v1[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v2[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v3[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v4[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v5[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v6[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v7[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v8[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v9[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v10[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v11[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v12[VN - 1 : 0];
reg [SCW - 1 : 0] WT_result_v13[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v0_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v1_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v2_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v3_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v4_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v5_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v6_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v7_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v8_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v9_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v10_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v11_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v12_wire[VN - 1 : 0];
wire [SCW - 1 : 0] WT_result_v13_wire[VN - 1 : 0];
reg [7 : 0] state_idx; // MSB 3: state; LSB 5: idx
reg [7 : 0] next_state_idx;
wire [2 : 0] state;
wire [4 : 0] idx;
assign state = state_idx[7 : 5];
assign idx = state_idx[4 : 0];
reg CST_LOW;
genvar j;
integer i;
Mid_wrapper_tp_k_gp_0 #(
.H (H),
.L (L),
.VN(VN)
) mid_wrappers (
.clk(clk),
.tree_rstn(rstn),
.valid(tree_valid),
.CST_LOW(CST_LOW),
.LM_sel(LM_sel),
.SW_in(TM_out),
.WT_0_out_S(WT_v0_out_S),
.WT_0_out_C(WT_v0_out_C),
.WT_1_out_S(WT_v1_out_S),
.WT_1_out_C(WT_v1_out_C),
.WT_2_out_S(WT_v2_out_S),
.WT_2_out_C(WT_v2_out_C),
.WT_3_out_S(WT_v3_out_S),
.WT_3_out_C(WT_v3_out_C),
.WT_4_out_S(WT_v4_out_S),
.WT_4_out_C(WT_v4_out_C),
.WT_5_out_S(WT_v5_out_S),
.WT_5_out_C(WT_v5_out_C),
.WT_6_out_S(WT_v6_out_S),
.WT_6_out_C(WT_v6_out_C),
.WT_7_out_S(WT_v7_out_S),
.WT_7_out_C(WT_v7_out_C),
.WT_8_out_S(WT_v8_out_S),
.WT_8_out_C(WT_v8_out_C),
.WT_9_out_S(WT_v9_out_S),
.WT_9_out_C(WT_v9_out_C),
.WT_10_out_S(WT_v10_out_S),
.WT_10_out_C(WT_v10_out_C),
.WT_11_out_S(WT_v11_out_S),
.WT_11_out_C(WT_v11_out_C),
.WT_12_out_S(WT_v12_out_S),
.WT_12_out_C(WT_v12_out_C),
.WT_13_out_S(WT_v13_out_S),
.WT_13_out_C(WT_v13_out_C)
);
generate
for (j = 0; j < VN; j = j + 1) begin : inst_SW_loop
MAC #(
.W_1(SCW), // input 1 width
.W_2(WP), // input 2 width
.W_O(TTW), // output width
.NUM(16) // parallel width
) mac (
.clk(clk),
.rstn(rstn),
.MAC_in_1({
{SCW{1'b0}},
{SCW{1'b0}},
WT_result_v13[j],
WT_result_v12[j],
WT_result_v11[j],
WT_result_v10[j],
WT_result_v9[j],
WT_result_v8[j],
WT_result_v7[j],
WT_result_v6[j],
WT_result_v5[j],
WT_result_v4[j],
WT_result_v3[j],
WT_result_v2[j],
WT_result_v1[j],
WT_result_v0[j]
}),
//.MAC_in_2({weight_0, weight_1, -5'd6, -5'd4, -5'd3, -5'd2, -5'd1, 5'd0, 5'd1, 5'd2, 5'd3, 5'd4, 5'd6, 5'd8, 5'd12, 5'd0}),
.MAC_out(MAC_out[j])
);
end
endgenerate
genvar k;
generate
for (k = 0; k < VN; k = k + 1) begin
assign WT_result_v0_wire[k] = {1'b0, final_S_v0[k]} + {final_C_v0[k], 1'b0};
assign WT_result_v1_wire[k] = {1'b0, final_S_v1[k]} + {final_C_v1[k], 1'b0};
assign WT_result_v2_wire[k] = {1'b0, final_S_v2[k]} + {final_C_v2[k], 1'b0};
assign WT_result_v3_wire[k] = {1'b0, final_S_v3[k]} + {final_C_v3[k], 1'b0};
assign WT_result_v4_wire[k] = {1'b0, final_S_v4[k]} + {final_C_v4[k], 1'b0};
assign WT_result_v5_wire[k] = {1'b0, final_S_v5[k]} + {final_C_v5[k], 1'b0};
assign WT_result_v6_wire[k] = {1'b0, final_S_v6[k]} + {final_C_v6[k], 1'b0};
assign WT_result_v7_wire[k] = {1'b0, final_S_v7[k]} + {final_C_v7[k], 1'b0};
assign WT_result_v8_wire[k] = {1'b0, final_S_v8[k]} + {final_C_v8[k], 1'b0};
assign WT_result_v9_wire[k] = {1'b0, final_S_v9[k]} + {final_C_v9[k], 1'b0};
assign WT_result_v10_wire[k] = {1'b0, final_S_v10[k]} + {final_C_v10[k], 1'b0};
assign WT_result_v11_wire[k] = {1'b0, final_S_v11[k]} + {final_C_v11[k], 1'b0};
assign WT_result_v12_wire[k] = {1'b0, final_S_v12[k]} + {final_C_v12[k], 1'b0};
assign WT_result_v13_wire[k] = {1'b0, final_S_v13[k]} + {final_C_v13[k], 1'b0};
end
endgenerate
// fsm next state generation
always @(state_idx, core_valid) begin
case(state)
3'b000: begin
if (core_valid == 1) next_state_idx = 8'b00100000;
else next_state_idx = 0;
end
3'b001: begin
if (idx == SCW) next_state_idx = 8'b01000000;
else next_state_idx = state_idx + 1;
end
3'b010: next_state_idx = 8'b01100000;
3'b011: next_state_idx = 8'b10000000;
3'b100: begin
if (core_valid == 0) next_state_idx = 0; //改为wait out逻辑,适配顶层
else next_state_idx = 8'b10000000;
end
default: next_state_idx = 0;
endcase
end
// fsm state transfer
always @(posedge clk) begin
if(!rstn)
state_idx <= 0;
else
state_idx <= next_state_idx;
end
//fsm output
always @(posedge clk or negedge rstn) begin
if (!rstn) begin
//state <= 0;
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
CST_LOW <= 0;
TM_sel <= 8'b00000000;
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= 0;
end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
end else begin
if (state == 0) begin
//idx <= 0;
tree_valid <= 0;
result_valid <= 0;
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i] <= 0;
final_C_v0[i] <= 0;
final_S_v1[i] <= 0;
final_C_v1[i] <= 0;
final_S_v2[i] <= 0;
final_C_v2[i] <= 0;
final_S_v3[i] <= 0;
final_C_v3[i] <= 0;
final_S_v4[i] <= 0;
final_C_v4[i] <= 0;
final_S_v5[i] <= 0;
final_C_v5[i] <= 0;
final_S_v6[i] <= 0;
final_C_v6[i] <= 0;
final_S_v7[i] <= 0;
final_C_v7[i] <= 0;
final_S_v8[i] <= 0;
final_C_v8[i] <= 0;
final_S_v9[i] <= 0;
final_C_v9[i] <= 0;
final_S_v10[i] <= 0;
final_C_v10[i] <= 0;
final_S_v11[i] <= 0;
final_C_v11[i] <= 0;
final_S_v12[i] <= 0;
final_C_v12[i] <= 0;
final_S_v13[i] <= 0;
final_C_v13[i] <= 0;
WT_result_v0[i] <= 0;
WT_result_v1[i] <= 0;
WT_result_v2[i] <= 0;
WT_result_v3[i] <= 0;
WT_result_v4[i] <= 0;
WT_result_v5[i] <= 0;
WT_result_v6[i] <= 0;
WT_result_v7[i] <= 0;
WT_result_v8[i] <= 0;
WT_result_v9[i] <= 0;
WT_result_v10[i] <= 0;
WT_result_v11[i] <= 0;
WT_result_v12[i] <= 0;
WT_result_v13[i] <= 0;
end
/*
if (valid == 1) begin
state <= 1;
end
else begin
state <= 0;
end
*/
end
else if (state == 1) begin
tree_valid <= 1;
if (idx == 0) begin
TM_sel <= 8'b00000001;
end else begin
case (idx)
1: TM_sel <= 8'b00000010;
2: TM_sel <= 8'b00000100;
3: TM_sel <= 8'b00001000;
4: TM_sel <= 8'b00010000;
5: TM_sel <= 8'b00100000;
6: TM_sel <= 8'b01000000;
7: TM_sel <= 8'b10000000;
default: TM_sel <= 8'b10000000; // signed extension
endcase
//if (idx == SCW) begin
// state <= 2;
//end
for (i = 0; i < VN; i = i + 1) begin
final_S_v0[i][idx-1] <= WT_v0_out_S[i];
final_C_v0[i][idx-1] <= WT_v0_out_C[i];
final_S_v1[i][idx-1] <= WT_v1_out_S[i];
final_C_v1[i][idx-1] <= WT_v1_out_C[i];
final_S_v2[i][idx-1] <= WT_v2_out_S[i];
final_C_v2[i][idx-1] <= WT_v2_out_C[i];
final_S_v3[i][idx-1] <= WT_v3_out_S[i];
final_C_v3[i][idx-1] <= WT_v3_out_C[i];
final_S_v4[i][idx-1] <= WT_v4_out_S[i];
final_C_v4[i][idx-1] <= WT_v4_out_C[i];
final_S_v5[i][idx-1] <= WT_v5_out_S[i];
final_C_v5[i][idx-1] <= WT_v5_out_C[i];
final_S_v6[i][idx-1] <= WT_v6_out_S[i];
final_C_v6[i][idx-1] <= WT_v6_out_C[i];
final_S_v7[i][idx-1] <= WT_v7_out_S[i];
final_C_v7[i][idx-1] <= WT_v7_out_C[i];
final_S_v8[i][idx-1] <= WT_v8_out_S[i];
final_C_v8[i][idx-1] <= WT_v8_out_C[i];
final_S_v9[i][idx-1] <= WT_v9_out_S[i];
final_C_v9[i][idx-1] <= WT_v9_out_C[i];
final_S_v10[i][idx-1] <= WT_v10_out_S[i];
final_C_v10[i][idx-1] <= WT_v10_out_C[i];
final_S_v11[i][idx-1] <= WT_v11_out_S[i];
final_C_v11[i][idx-1] <= WT_v11_out_C[i];
final_S_v12[i][idx-1] <= WT_v12_out_S[i];
final_C_v12[i][idx-1] <= WT_v12_out_C[i];
final_S_v13[i][idx-1] <= WT_v13_out_S[i];
final_C_v13[i][idx-1] <= WT_v13_out_C[i];
end
end
//idx <= idx + 1;
end
else if (state == 2) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_v0[i] <= WT_result_v0_wire[i];
WT_result_v1[i] <= WT_result_v1_wire[i];
WT_result_v2[i] <= WT_result_v2_wire[i];
WT_result_v3[i] <= WT_result_v3_wire[i];
WT_result_v4[i] <= WT_result_v4_wire[i];
WT_result_v5[i] <= WT_result_v5_wire[i];
WT_result_v6[i] <= WT_result_v6_wire[i];
WT_result_v7[i] <= WT_result_v7_wire[i];
WT_result_v8[i] <= WT_result_v8_wire[i];
WT_result_v9[i] <= WT_result_v9_wire[i];
WT_result_v10[i] <= WT_result_v10_wire[i];
WT_result_v11[i] <= WT_result_v11_wire[i];
WT_result_v12[i] <= WT_result_v12_wire[i];
WT_result_v13[i] <= WT_result_v13_wire[i];
end
//idx <= 0;
//state <= 3;
tree_valid <= 0;
end
else if (state == 3) begin
// MAC adder
//state <= 4;
end
else if (state == 4) begin
for (i = 0; i < VN; i = i + 1) begin
WT_result_acc[i] <= MAC_out[i];
end
//state <= 5;
result_valid <= 1;
end
//if (state == 5) begin
// ouput, reserved cycle
//idx <= 0;
//state <= 0;
//tree_valid <= 0;
//end
end
end
endmodule
module FSM_wrapper_tp_k
#(
parameter H = 896, // hidden layer dim, down: 4864 other: 896
parameter L = 24, // layer num
parameter VN = 32, // vector num per group, down: 8 logits: 116 other: 32
parameter TTVN = 128, // total vector num, vary in kqv...
parameter WP = 5, // weight precision
parameter AP = 8, // activation precision
parameter SCW = 19, // related to global max WW, AP+log2(WW_max)
parameter SCWB = 5, // log2(SCW)
parameter TTW = 28 // MAC output total width, SCW + WP + 4
)
(
input clk,
input core_valid,
input rstn,
input [L - 1 : 0] LM_sel,
input [AP - 1 : 0] Top_in[H - 1 : 0],
// 下面的输出直接传出去,所以不是reg类型
output [TTW - 1 : 0] WT_result_acc[TTVN - 1 : 0],
output result_valid
);
endmodule
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment