Commit 81a85a07 by root

refactor reward and tool integration

parent 6432075e
......@@ -10,6 +10,7 @@
**/wandb
**/ckpt
**/tests/datasets
**/reports
*.slurm*
/*.sh*
**/tmp/
......
......@@ -46,6 +46,15 @@ make run-online
- `SandboxFusion/sandbox/utils/execution.py`里面那个`get_tmp_dir`可能需要自定义一下,比如腾讯云上面就需要考虑哪边存储空间大。
- `SandboxFusion/sandbox/utils/logging.py`里面的日志级别可能需要改一下,默认的DEBUG输出的东西有点多。
### eda工具相关
由于`verl`和`sandbox-runtime`两个环境都需要用到eda相关的代码,因此我把它给单独拎出来了,放在`eda_tools`文件夹下面。`verl`和`sandbox-runtime`这两个环境都需要进行如下安装:
```bash
cd eda_tools
pip install -e .
cd ..
```
### 数据预处理
由于原始的eurus2数据集缺一些东西,因此需要进行预处理。代码在`examples/data_preprocess/convert_eurus_tir.py`
......@@ -259,4 +268,7 @@ verl在`verl/utils/dataset/rl_dataset.py`的`_read_files_and_tokenize`函数里
## 例子(Inference)
verl目前的`verl/trainer/main_generation.py`不支持工具调用,我复制了一些RL训练的代码把它改的支持工具调用了,文件在`verl/trainer/custom_generation.py`,调用脚本在`generate_test.sh`
此外还对`verl/trainer/config/generation.yaml`进行了一些小修改,适配了一点配置。
\ No newline at end of file
此外还对`verl/trainer/config/generation.yaml`进行了一些小修改,适配了一点配置。
(**这个还没改好!!!**)`verl/trainer/custom_generation.py`中间有一些坑,比如要设置`torch.distributed`的后端,默认的GLOO有30分钟的通信等待时间上限,似乎会限制生成的总时间。
此外,需要加更详细的工具调用反馈,比如对于编译错误就需要反馈。这个在`verl/tools/sandbox_fusion_tools.py`的`SandboxFusionTool.execute_code`里面改。
\ No newline at end of file
......@@ -42,6 +42,7 @@ async def run_command_bare(command: str | List[str],
extra_env: Optional[Dict[str, str]] = {},
use_exec: bool = False,
preexec_fn=None) -> CommandRunResult:
# time.sleep(0)
print(f'running command {command}')
try:
logger.debug(f'running command {command}')
......@@ -70,7 +71,6 @@ async def run_command_bare(command: str | List[str],
if stdin is not None:
try:
if p.stdin:
print('p.stdin is', p.stdin)
p.stdin.write(stdin.encode())
# p.stdin.flush()
await p.stdin.drain()
......
......@@ -116,6 +116,16 @@ async def run_verilog(args: CodeRunArgs) -> CodeRunResult:
return await run_commands(compile_cmd, run_cmd, tmp_dir, {}, args)
from eda_tools.core import ppa_one_sample
async def run_verilog_ppa(args: CodeRunArgs) -> CodeRunResult:
    """Placeholder for Verilog PPA (power/performance/area) evaluation.

    Not implemented yet: currently always returns ``None``.

    Intended behavior (see TODO below): run ``ppa_one_sample`` on
    ``args.code`` — wrapped so the blocking EDA toolchain does not stall the
    event loop — and surface compile / runtime errors in the returned
    ``CodeRunResult``.
    """
    # TODO: needs an async wrapper around the blocking ppa_one_sample call,
    # and must report compile errors and runtime errors back to the caller.
    pass
    # ppa = ppa_one_sample(args.code)
    # return CodeRunResult(compile_result=ppa)
async def run_lean(args: CodeRunArgs) -> CodeRunResult:
deps_dir = os.path.abspath(os.path.join(__file__, '../../../runtime/lean'))
with tempfile.TemporaryDirectory(dir=get_tmp_dir(), ignore_cleanup_errors=True) as tmp_dir:
......
由于有多个conda环境需要用这部分eda_tools代码,故把它封装成单独的包。
安装方式:`pip install -e .`
\ No newline at end of file
from .core import *
from .utils import *
\ No newline at end of file
from verl.utils.reward_score.codev_eval_toolkit.verify import eda_tools
from eda_tools.utils import eda_tools
import json
import re
import os
from tqdm.contrib.concurrent import process_map
from multiprocessing import Process, Queue
import psutil
import hashlib
import random
import platform
# import platform
# # 根据不同系统导入不同的文件锁模块
# if platform.system() == 'Windows':
......@@ -107,6 +106,97 @@ def verify_one_sample(gold_code, dut_code, uid=None):
return result
def ppa_one_sample(code, uid=None):
    """Synthesize one Verilog sample and report its PPA metrics.

    Args:
        code: Verilog source text of the design to evaluate.
        uid: Ignored; a fresh unique id is always derived from ``code`` plus a
            random salt so concurrent workers never collide on temp paths.

    Returns:
        On success, the dict produced by ``eda_tools.yosys_opensta_ppa``
        (area / power / timing / pre- and post-map cell counts).
        If the code cannot be parsed: ``{"correct": False, "parse_error": ...}``.
        If synthesis itself fails: a dict with every metric set to ``"N/A"``.
    """
    import shutil  # function-scope: only needed for scratch-dir cleanup

    # Derive a collision-resistant id from the code plus a random salt.
    uid = code + str(random.randint(0, 2147483647))
    uid = hashlib.md5(uid.encode("utf-8")).hexdigest()

    eda = eda_tools(quiet=True)
    try:
        top = eda.auto_top(code)
    except Exception as e:
        # The code cannot even be parsed for a top module (syntax problems).
        return {"correct": False, "parse_error": e.args}

    file_path = os.path.abspath(f"./tmp/testcase/{uid}_syn.v")
    build_dir = os.path.abspath(f"./tmp/work/{uid}")
    # exist_ok=True already makes these race-safe; no exists() pre-check needed.
    os.makedirs("./tmp/testcase", exist_ok=True)
    os.makedirs("./tmp/work", exist_ok=True)
    os.makedirs(build_dir, exist_ok=True)

    with open(file_path, "w") as f:
        f.write(code)

    clk_name = "clk"  # default; overwritten below when a real clock is found
    try:
        input_ports, output_ports, clock_port_polarity, reset_ports = eda.extract_golden_ports(
            golden_path=file_path,  # Path to temporary file (required)
            golden_top=top,  # Top module name (required)
        )
        # Safely handle clock ports (avoid empty-set / multiple-clock surprises).
        if clock_port_polarity:
            # Take the first clock; downstream tooling supports only one.
            first_clock = next(iter(clock_port_polarity))
            clk_name = first_clock[0]  # Extract clock name (e.g., "clk")
            clk_edge = "posedge" if first_clock[1] == 1 else "negedge"
            print(f"[INFO] Extracted clock for id={uid}: {clk_name} ({clk_edge})")
            if len(clock_port_polarity) > 1:
                print(f"[WARNING] Multiple clocks found for id={uid} (only use {clk_name})")
        else:
            # BUGFIX: previously clk_name was left unassigned on this path, so
            # the synthesis call below raised NameError and the sample was
            # silently reported as all-"N/A".  Fall back to the default 'clk'.
            print(f"[WARNING] No clock port found for id={uid}, use default 'clk'")
    except Exception as e:
        # Catch all exceptions (e.g., Yosys not installed, Verilog syntax errors).
        print(f"[ERROR] Failed to extract ports for id={uid}: {str(e)}")

    result = None
    base_dir = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
    try:
        result = eda.yosys_opensta_ppa(
            job_name=f"syn_{uid}",
            rtl_paths=[file_path],
            top=top,
            clk=clk_name,
            tech="freepdk45",
            timeout=60,
            build_dir=build_dir,
            yosys_script=f'{base_dir}/scripts/yosys.ys'
        )
    except Exception as e:
        # Synthesis failed: report every metric as unavailable.
        result = {
            "area": "N/A",
            "power": "N/A",
            "time": "N/A",
            "premap_cells": "N/A",
            "premap_wires": "N/A",
            "postmap_cells": "N/A",
            "postmap_wires": "N/A"
        }
    finally:
        # Always clean up the per-sample scratch files.
        if os.path.exists(file_path):
            os.remove(file_path)
        # shutil.rmtree is safer than shelling out to `rm -r`.
        shutil.rmtree(build_dir, ignore_errors=True)
    return result
def kill_process_tree(pid):
parent = psutil.Process(pid)
children = parent.children(recursive=True) # 获取所有子进程
......@@ -115,9 +205,9 @@ def kill_process_tree(pid):
parent.terminate() # 终止父进程
def verify_one_sample_wrapper(args):
def run_function_with_timeout(func, *args):
def target(queue):
result = verify_one_sample(*args)
result = func(*args)
queue.put(result)
queue = Queue()
......@@ -126,13 +216,11 @@ def verify_one_sample_wrapper(args):
process.join(timeout=30)
if process.is_alive():
# 如果超时,终止进程
kill_process_tree(process.pid)
process.join()
print("Function timed out!")
return {"correct": False, "timeout": True}
else:
# 返回结果
return queue.get()
......@@ -148,42 +236,17 @@ def extract_verilog(verilog_code):
if __name__ == "__main__":
for part in range(16):
name = f"codev_dataset_165k_o1_part{part}"
with open(f"data/evolve/{name}.jsonl", "r") as f:
data_gold = list(map(json.loads, f.read().strip().splitlines()))
data_gold = [extract_verilog(x["response"][0]["content"]) for x in data_gold]
with open(f"results/evolve/sample/{name}.jsonl", "r") as f:
data_dut = list(map(json.loads, f.read().strip().splitlines()))
problem_ids = [x["problem_id"] for x in data_dut]
data_dut = [extract_verilog(x["response"][0]["content"]) for x in data_dut]
print(len(data_gold), len(data_dut), len(problem_ids))
assert len(data_dut) % len(data_gold) == 0
n_sample = len(data_dut) // len(data_gold)
testcases = []
for i, dut in enumerate(data_dut):
gold = data_gold[i // n_sample]
testcases.append((gold, dut, i))
# testcases = testcases[:1000]
if not os.path.exists("./tmp/testcase"):
os.makedirs("./tmp/testcase")
if not os.path.exists("./tmp/work"):
os.makedirs("./tmp/work")
# cpu_num = multiprocessing.cpu_count()
cpu_num = 64
# chunksize = max(len(testcases) // (cpu_num * 5), 1)
chunksize = 1
results = process_map(verify_one_sample_wrapper, testcases, max_workers=cpu_num, chunksize=chunksize)
for i in range(len(results)):
results[i]["problem_id"] = problem_ids[i]
with open(f"results/evolve/eval/{name}.jsonl", "w") as f:
f.write("\n".join(map(json.dumps, results)) + "\n")
print(f"{name}.jsonl is processed!!!")
# file = "/nfs_global/S/zhuyaoyu/projects/CodeV-o1/data/source/codev_dataset_165k_wo_module_head.jsonl"
file = "/nfs_global/datasets/codev/codev_dataset_165k_v3.jsonl"
import json
with open(file, "r") as f:
data = list(map(json.loads, f.read().strip().splitlines()))
sep = "============================================"
# 正确
example_ans = data[2]["response"]
example_output = f"<think></think> <answer>\n```verilog\n{example_ans}```\n</answer>"
# reward = compute_score(example_output, example_ans)
ppa = ppa_one_sample(example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{ppa}")
\ No newline at end of file
......@@ -16,7 +16,6 @@ from siliconcompiler.targets import (
skywater130_demo,
) # import predefined technology and flow target
def llm_request(prompt, temperature=0.5):
api_key = os.getenv("tencent_key")
base_url = "https://api.lkeap.cloud.tencent.com/v1"
......@@ -108,7 +107,7 @@ class eda_tools:
new_code = re.sub(note_pattern, "", verilog_code)
new_code = re.sub(r"(?:\s*?\n)+", "\n", new_code)
module_def_pattern = r"(module\s+)([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*\#\s*\([\s\S]*?\))?(\s*(?:\([^;]*\))?\s*;)([\s\S]*?)?(endmodule)"
module_defs = re.findall(module_def_pattern, new_code, re.DOTALL)
module_defs = re.findall(module_def_pattern, verilog_code, re.DOTALL)
if not module_defs:
raise Exception("No module found in auto_top().")
module_names = [m[1] for m in module_defs]
......@@ -119,7 +118,7 @@ class eda_tools:
this_mod_body = mod[4]
for submod in module_names:
if submod != this_mod_name:
module_instance_pattern = rf"({re.escape(submod)})(\s)(\s*\#\s*\([\s\S]*?\))?([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*(?:\([^;]*\))?\s*;)"
module_instance_pattern = rf"({submod})(\s)(\s*\#\s*\([\s\S]*?\))?([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*(?:\([^;]*\))?\s*;)"
module_instances = re.findall(
module_instance_pattern, this_mod_body, re.DOTALL
)
......@@ -138,7 +137,7 @@ class eda_tools:
new_code = re.sub(note_pattern, "", verilog_code)
new_code = re.sub(r"(?:\s*?\n)+", "\n", new_code)
module_def_pattern = r"(module\s+)([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*\#\s*\([\s\S]*?\))?(\s*(?:\([^;]*\))?\s*;)([\s\S]*?)?(endmodule)"
module_defs = re.findall(module_def_pattern, new_code, re.DOTALL)
module_defs = re.findall(module_def_pattern, verilog_code, re.DOTALL)
module_names = [m[1] for m in module_defs]
for submod in module_names:
module_instance_pattern = rf"({submod})(\s+)(\#\s*\([\s\S]*?\)\s*)?([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*(?:\([^;]*\))?\s*;)"
......@@ -146,7 +145,7 @@ class eda_tools:
new_code = re.sub(module_def_pattern, rf"\1\2{suffix}\3\4\5\6", new_code)
return new_code
def extract_golden_ports(self, golden_path, golden_top, timeout=60):
def extract_golden_ports(self, golden_path, golden_top):
"""
根据yosys的结果,提取golden模块的输入输出端口、时钟端口、复位端口。
golden_path: 参考设计的路径
......@@ -159,13 +158,12 @@ class eda_tools:
clock_port_polarity: 时钟端口名、上升沿/下降沿触发
reset_port_polarity_sync: 复位端口名、高低电平有效、同步/异步复位
"""
golden_top = golden_top.lstrip("\\")
yosys_script = f"read_verilog {golden_path}; prep -top {golden_top} -flatten; opt_dff -nodffe; json -compat-int; exec -- echo 'Happy new year~';"
yosys_result = subprocess.run(
["yosys", "-p", yosys_script],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
timeout=timeout,
timeout=60,
)
if yosys_result.stderr:
raise Exception(yosys_result.stderr.decode("utf-8"))
......@@ -179,8 +177,6 @@ class eda_tools:
ports_ids_dict = {}
input_port_width = set()
output_port_width = set()
if yosys_json["modules"] == {}:
raise Exception("No module found in yosys output after synthesis.")
for port_name in yosys_json["modules"][golden_top]["ports"]:
direction = yosys_json["modules"][golden_top]["ports"][port_name][
"direction"
......@@ -204,7 +200,7 @@ class eda_tools:
elif len(bits) > 1 and port_id[0] in bits:
return f"{port_name}[{bits.index(port_id[0])}]"
else:
return None
raise Exception(f"Cannot find port id {port_id[0]}")
for cell_id in yosys_json["modules"][golden_top]["cells"]:
cell = yosys_json["modules"][golden_top]["cells"][cell_id]
......@@ -212,9 +208,8 @@ class eda_tools:
if reg_ports == "CLK":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"]["CLK_POLARITY"]
clock_port_polarity.add((port_name, polarity))
polarity = cell["parameters"]["CLK_POLARITY"]
clock_port_polarity.add((port_name, polarity))
break
match cell["type"]:
case "$adff" | "$adffe" | "$adlatch":
......@@ -222,36 +217,27 @@ class eda_tools:
if reg_ports == "ARST":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"]["ARST_POLARITY"]
sync = False
reset_port_polarity_sync.add(
(port_name, polarity, sync)
)
polarity = cell["parameters"]["ARST_POLARITY"]
sync = False
reset_port_polarity_sync.add((port_name, polarity, sync))
break
case "$sdff" | "$sdffe" | "$sdffce":
for reg_ports in cell["connections"]:
if reg_ports == "SRST":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"]["SRST_POLARITY"]
sync = True
reset_port_polarity_sync.add(
(port_name, polarity, sync)
)
polarity = cell["parameters"]["SRST_POLARITY"]
sync = True
reset_port_polarity_sync.add((port_name, polarity, sync))
break
case "$dffsr" | "$dffsre" | "$dlatchsr" | "$sr":
for reg_ports in cell["connections"]:
if reg_ports == "SET" or reg_ports == "CLR":
if reg_ports == "SET" | "CLR":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"][f"{reg_ports}_POLARITY"]
sync = False
reset_port_polarity_sync.add(
(port_name, polarity, sync)
)
polarity = cell["parameters"][f"{reg_ports}_POLARITY"]
sync = False
reset_port_polarity_sync.add((port_name, polarity, sync))
break
case "$dlatch" | "$ff" | "$dff" | "$dffe" | "aldff" | "$aldffe":
pass
......@@ -367,7 +353,7 @@ class eda_tools:
# 生成随机化输入信号的task
randomize_inputs_lines = "\n ".join(
[
f"{port}_in = {{{', '.join(['$random(seed)']*math.ceil(width/32))}}};"
f"{port}_in = {{{', '.join(['$random']*math.ceil(width/32))}}};"
for port, width in input_port_width
if port not in [clock_port_name] + list(reset_port_names)
]
......@@ -409,7 +395,6 @@ class eda_tools:
(
"\n ".join(sync_reset_lines)
+ "\n # 10; toggle_clock; # 10; toggle_clock;\n "
+ "\n ".join(unset_lines)
)
if sync_reset_lines
else "" + "\n ".join(async_reset_lines + unset_lines)
......@@ -439,19 +424,9 @@ class eda_tools:
end
endtask
"""
count_errors_task = f"""// Task to count errors
task count_errors;
begin
if (trigger === 1'b1) begin
num_errors = num_errors + 1;
end
num_all = num_all + 1;
end
endtask
"""
# 生成随机复位信号的task
random_reset_lines = "\n ".join(
[f"{port}_in = $random(seed);" for port in reset_port_names]
[f"{port}_in = $random;" for port in reset_port_names]
)
random_reset_task = f"""// Task for random reset
......@@ -466,21 +441,17 @@ class eda_tools:
initial_block_lines = [
"// initial block for random tests and targed tests",
"initial begin",
' if (!$value$plusargs("seed=%d", seed)) seed = 0;',
f' if (!$value$plusargs("outerLoopNum=%d", outerLoopNum)) outerLoopNum = {self.random_seq_num};',
f' if (!$value$plusargs("innerLoopNum=%d", innerLoopNum)) innerLoopNum = {self.random_seq_steps};',
(
f" {clock_port_name}_in = {0 if clock_port_edge else 1};"
if clock_port_name
else ""
),
f" repeat (outerLoopNum) begin",
f" repeat ({self.random_seq_num}) begin",
" random_reset;" if reset_port_names else "",
" #100; count_errors;",
f" repeat (innerLoopNum) begin",
" #100;",
f" repeat ({self.random_seq_steps}) begin",
" #100; randomize_inputs;",
" #100; toggle_clock;" if clock_port_name else "",
" #100; count_errors;",
" end",
" end",
]
......@@ -488,21 +459,20 @@ class eda_tools:
initial_block_lines.append(" #100;")
for i in range(len(reset_task_list)):
initial_block_lines.append(
f" repeat (outerLoopNum) begin",
f" repeat ({self.random_seq_num}) begin",
)
initial_block_lines.append(f" reset_{i};")
initial_block_lines.append(f" #100; count_errors;")
initial_block_lines.append(f" reset_{i};")
initial_block_lines.append(f" #100;")
initial_block_lines.append(
f" repeat (innerLoopNum) begin",
f" repeat ({self.random_seq_steps}) begin",
)
initial_block_lines.append(f" #100; randomize_inputs;")
initial_block_lines.append(f" #100; randomize_inputs;")
(
initial_block_lines.append(f" #100; toggle_clock;")
initial_block_lines.append(f" #100; toggle_clock;")
if clock_port_name
else ""
)
initial_block_lines.append(f" #100; count_errors;")
initial_block_lines.append(f" end")
initial_block_lines.append(f" end")
initial_block_lines.append(f" end")
if self.use_directed_tests:
......@@ -515,12 +485,7 @@ class eda_tools:
else:
initial_block_lines.append(" directed_tests;")
initial_block_lines += [
' $display("Number of all tests: %d", num_all);',
' $display("Number of errors: %d", num_errors);',
' $display("Error rate: %.8f", num_errors/num_all);',
" if (num_errors == 0) begin",
' $display("All tests passed.");',
" end",
' $display("All tests passed.");',
" $finish;",
"end",
]
......@@ -544,11 +509,6 @@ module testbench;
{gate_output_defs}
reg trigger;
real num_all = 0;
real num_errors = 0;
integer seed;
integer outerLoopNum;
integer innerLoopNum;
{golden_top}{self.golden_suffix} gold (
{gold_port_mappings}
......@@ -562,8 +522,8 @@ module testbench;
{random_reset_task if reset_port_names else ""}
{randomize_inputs_task}
{directed_tests_task if self.use_directed_tests else ""}
{count_errors_task}
{initial_block}
{monitor_block}
endmodule
"""
return testbench_code
......@@ -684,7 +644,6 @@ Example of expected response format:
print("Failed to generate directed tests with LLM.")
with open(os.path.join(tb_dir, "tb.v"), "w") as f:
f.write(renamed_golden_code)
f.write("\n")
f.write(tb_module_code)
return renamed_golden_code, renamed_gate_code, tb_module_code
......@@ -695,10 +654,6 @@ Example of expected response format:
golden_top,
gate_top,
tb_dir,
seed=0,
outerLoopNum=None,
innerLoopNum=None,
timeout=60,
):
"""
一个wrapper,用于生成testbench并运行,输出测试结果,如果运行的输出中有"all tests passed."则测试通过,否则测试失败
......@@ -709,18 +664,10 @@ Example of expected response format:
golden_top: 参考设计的顶层模块名
gate_top: 待测设计的顶层模块名
tb_dir: 生成的testbench所在的路径,包括Makefile,verilator在这个路径下运行
seed: 控制testbench测试时的随机种子,默认0,和之前对齐
outerLoopNum: 控制testbench测试时的外层循环次数,默认None,此时采用self.random_seq_num和之前对齐
innerLoopNum: 控制testbench测试时的内层循环次数,默认None,此时采用self.random_seq_steps和之前对齐
timeout: 仿真超时时间,默认60s
输出:
返回值为True或False,表示测试通过或失败
"""
if outerLoopNum == None:
outerLoopNum = self.random_seq_num
if innerLoopNum == None:
innerLoopNum = self.random_seq_steps
(
input_port_width,
output_port_width,
......@@ -739,28 +686,20 @@ Example of expected response format:
reset_port_polarity_sync=reset_port_polarity_sync,
)
command = f"iverilog -g2012 -o {os.path.join(tb_dir,'tb.vvp')} -s testbench {os.path.join(tb_dir,'*.v')} && {os.path.join(tb_dir,f'tb.vvp +seed={seed} +outerLoopNum={outerLoopNum} +innerLoopNum={innerLoopNum}')}"
command = f"iverilog -g2012 -o {os.path.join(tb_dir,'tb.vvp')} -s testbench {os.path.join(tb_dir,'*.v')} && {os.path.join(tb_dir,'tb.vvp')}"
res = subprocess.run(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
timeout=timeout,
timeout=60,
)
error_rate_pattern = r"Error rate:\s*(\d+\.\d+)\n"
print(res.stdout.decode("utf-8")) if not self.quiet else None
if re.search(error_rate_pattern, res.stdout.decode("utf-8")):
error_rate = float(
re.search(error_rate_pattern, res.stdout.decode("utf-8")).group(1)
)
else:
error_rate = 1.0
# TODO 实现得比较丑陋
if "All tests passed." in res.stdout.decode("utf-8"):
print("Test passed!") if not self.quiet else None
return (
True,
error_rate,
input_port_width,
output_port_width,
clock_port_polarity,
......@@ -771,7 +710,6 @@ Example of expected response format:
print("Test failed!") if not self.quiet else None
return (
False,
error_rate,
input_port_width,
output_port_width,
clock_port_polarity,
......@@ -787,7 +725,6 @@ Example of expected response format:
tech: str = "freepdk45",
timeout=60,
build_dir="./work/build",
cache_dir="./siliconcompiler",
) -> dict:
"""
用来综合一个设计,返回综合结果,包括cell_area, peak_power, arrival_time
......@@ -804,7 +741,6 @@ Example of expected response format:
"""
chip = Chip(top) # create chip object
chip.set("option", "builddir", build_dir)
chip.set("option", "cachedir", cache_dir)
chip.set("option", "jobname", job_name)
chip.set("option", "clean", True)
chip.set("option", "loglevel", "critical" if self.quiet else "info")
......@@ -843,6 +779,162 @@ Example of expected response format:
return ppa
###################################################################################
def yosys_opensta_ppa(
self,
job_name: str,
rtl_paths: list,
top: str,
clk: str = None,
tech: str = "freepdk45",
timeout: int = 60,
build_dir: str = "./work/build",
yosys_script: str = "yosys.ys"
) -> dict:
base_dir = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
ppa = {
"cell_area": None,
"peak_power": None,
"arrival_time": None,
"premap_cells": None,
"premap_wires": None,
"postmap_cells": None,
"postmap_wires": None
}
for path in rtl_paths:
file_path = path
match tech:
case "freepdk45":
t_ech = f"{base_dir}/tech/NangateOpenCellLibrary_typical.lib"
case _:
raise ValueError(f"Unsupported technology {tech}")
top_module = top
os.makedirs("reports", exist_ok=True)
file_id = file_path.split("/")[-1].split(".")[0]
report_dir = os.path.join("reports", file_id)
os.makedirs(report_dir, exist_ok=True)
power_path = os.path.join(report_dir, "power.rpt")
unconstrained_path = os.path.join(report_dir, "unconstrained.rpt")
with open(yosys_script, "r") as f:
ys_template = f.read()
ys_script = ys_template.format(
file_path=file_path,
top_module=top_module,
t_ech=t_ech
)
try:
result = subprocess.run(
["yosys"],
input=ys_script,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True,
timeout=timeout
)
except subprocess.TimeoutExpired:
print(f"Error: Yosys command timed out after {timeout} seconds.")
ppa["error"] = "timeout"
return ppa
except subprocess.CalledProcessError as e:
print(f"Yosys command failed with error:\n{e.stderr}")
ppa["error"] = "command_failed"
return ppa
if result.returncode == 0:
for line in result.stdout.split('\n'):
if "Chip area for" in line:
match = re.search(r'\d+\.\d+', line)
if match:
chip_area = float(match.group())
ppa["cell_area"] = chip_area
else:
print("No valid value found:", line.strip())
seen_premap_wires = False
seen_premap_cells = False
for line in result.stdout.split('\n'):
stripped = line.strip()
if stripped.startswith("Number of wires:"):
parts = stripped.split()
try:
count = int(parts[-1])
except (IndexError, ValueError):
print("No valid wire count found:", stripped)
continue
if not seen_premap_wires:
ppa["premap_wires"] = count
seen_premap_wires = True
else:
ppa["postmap_wires"] = count
elif stripped.startswith("Number of cells:"):
parts = stripped.split()
try:
count = int(parts[-1])
except (IndexError, ValueError):
print("No valid cell count found:", stripped)
continue
if not seen_premap_cells:
ppa["premap_cells"] = count
seen_premap_cells = True
else:
ppa["postmap_cells"] = count
with open(f"{base_dir}/scripts/sta.ys", "r") as f:
sta_template = f.read()
clock_constraint = ""
if clk is not None:
clock_constraint = f'create_clock -name {clk} -period 10 [get_ports {clk}]\n'
sta_script = sta_template.replace("{t_ech}", t_ech).replace("{top_module}", top_module).replace("{file_path}", file_path).replace("{file_id}", file_id)
if clock_constraint:
sta_script = sta_script.replace(
f'link_design $sc_design',
f'link_design $sc_design\n{clock_constraint}'
)
result = subprocess.run(
["sta"],
input=sta_script,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True
)
try:
with open(power_path, 'r') as f:
for line in f:
if line.startswith('Total'):
match = re.search(r'^Total\s+\S+\s+\S+\s+\S+\s+(\S+)', line)
if match:
ppa["peak_power"] = float(match.group(1)) * 1000
else:
print(f"No peak_power value found in {power_path}")
break
except FileNotFoundError:
print(f"File not found: {power_path}")
try:
with open(unconstrained_path, 'r') as f:
for line in f:
if 'data arrival time' in line:
match = re.search(r'^\s*(\d+\.\d+)\s+data arrival time', line)
if match:
ppa["arrival_time"] = float(match.group(1))
else:
print(f"No arrival_time value found in {unconstrained_path}")
break
except FileNotFoundError:
print(f"File not found: {unconstrained_path}")
return ppa
###################################################################################
class myLogger:
    def __init__(self):
        """Initialize the logger with an empty message buffer."""
        # Ordered buffer of captured log messages.
        self.log = []
......@@ -894,17 +986,6 @@ def main():
f.write(eda.process_verilog(gate_code, eda.gate_suffix))
f.write("\n")
f.write(tb)
equiv, error_rate, _, _, _, _ = eda.equiv_with_testbench(
"./temp/gold.v",
"./temp/gate.v",
gold_top,
gate_top,
"./temp/testbench",
seed=0,
outerLoopNum=100,
innerLoopNum=1000,
)
print(equiv, error_rate)
if __name__ == "__main__":
......
// Techmap: lower the Yosys generic full-adder cell ($fa) onto FreePDK45
// standard cells.  Per bit, if any input is a known constant 0 the carry
// chain degenerates and a half adder (HA_X1) on the remaining two inputs
// suffices; otherwise a full adder (FA_X1) is instantiated.
(* techmap_celltype = "$fa" *)
module _tech_fa (A, B, C, X, Y);
parameter WIDTH = 1;
(* force_downto *)
input [WIDTH-1 : 0] A, B, C;
(* force_downto *)
output [WIDTH-1 : 0] X, Y;
// Per-input constant-value hints supplied by the techmap pass ('bx = unknown).
parameter _TECHMAP_CONSTVAL_A_ = WIDTH'bx;
parameter _TECHMAP_CONSTVAL_B_ = WIDTH'bx;
parameter _TECHMAP_CONSTVAL_C_ = WIDTH'bx;
genvar i;
generate for (i = 0; i < WIDTH; i = i + 1) begin
if (_TECHMAP_CONSTVAL_A_[i] === 1'b0 || _TECHMAP_CONSTVAL_B_[i] === 1'b0 || _TECHMAP_CONSTVAL_C_[i] === 1'b0) begin
if (_TECHMAP_CONSTVAL_C_[i] === 1'b0) begin
// C[i] is constant 0: half-add A and B.
HA_X1 halfadder_Cconst (
.A(A[i]),
.B(B[i]),
.CO(X[i]), .S(Y[i])
);
end
else begin
if (_TECHMAP_CONSTVAL_B_[i] === 1'b0) begin
// B[i] is constant 0: half-add A and C.
HA_X1 halfadder_Bconst (
.A(A[i]),
.B(C[i]),
.CO(X[i]), .S(Y[i])
);
end
else begin
// A[i] is constant 0: half-add B and C.
HA_X1 halfadder_Aconst (
.A(B[i]),
.B(C[i]),
.CO(X[i]), .S(Y[i])
);
end
end
end
else begin
// No constant-zero input: full adder with carry-out X and sum Y.
FA_X1 fulladder (
.A(A[i]), .B(B[i]), .CI(C[i]), .CO(X[i]), .S(Y[i])
);
end
end endgenerate
endmodule
// Techmap file implementing Yosys generic latch using FreePDK45 standard cell.
// Source: https://github.com/The-OpenROAD-Project/OpenROAD-flow-scripts/blob/10f030fccea33ec8084bb3974024d840cabc3782/flow/platforms/nangate45/cells_latch.v
// Positive-enable D latch: maps the Yosys generic cell onto FreePDK45
// DLH_X1 (active-high gate input G).
module $_DLATCH_P_(input E, input D, output Q);
DLH_X1 _TECHMAP_REPLACE_ (
    .D(D),
    .G(E),
    .Q(Q)
);
endmodule
// Negative-enable D latch: maps the Yosys generic cell onto FreePDK45
// DLL_X1 (active-low gate input GN).
module $_DLATCH_N_(input E, input D, output Q);
DLL_X1 _TECHMAP_REPLACE_ (
    .D(D),
    .GN(E),
    .Q(Q)
);
endmodule
\ No newline at end of file
// Tri-state buffer: maps the Yosys generic cell onto FreePDK45 TBUF_X1
// (data A, enable E on EN, output Y from the Z pin).
module \$_TBUF_ (input A, input E, output Y);
TBUF_X1 _TECHMAP_REPLACE_ (
    .A(A),
    .Z(Y),
    .EN(E));
endmodule
// Lookahead carry unit ($lcu) implemented as a Kogge-Stone parallel-prefix
// network: log2(WIDTH) combine stages over propagate (P) and generate (G)
// vectors produce all carry-outs CO.
(* techmap_celltype = "$lcu" *)
module _80_lcu_kogge_stone (
    P,
    G,
    CI,
    CO
);
    parameter WIDTH = 2;

    (* force_downto *)
    input [WIDTH-1:0] P, G;
    input CI;
    (* force_downto *)
    output [WIDTH-1:0] CO;

    integer i, j;
    (* force_downto *)
    reg [WIDTH-1:0] p, g;

    // Tell techmap to post-process the expanded netlist.
    wire [1023:0] _TECHMAP_DO_ = "proc; opt -fast";

    always @* begin
        p = P;
        g = G;

        // in almost all cases CI will be constant zero
        // Fold the carry-in into bit 0's generate term.
        g[0] = g[0] | (p[0] & CI);

        // Kogge-Stone prefix combine: at stage i, each bit j merges with
        // the bit 2^i positions below it.
        for (i = 0; i < $clog2(WIDTH); i = i + 1) begin
            for (j = WIDTH - 1; j >= 2 ** i; j = j - 1) begin
                g[j] = g[j] | p[j] & g[j-2**i];
                p[j] = p[j] & p[j-2**i];
            end
        end
    end

    assign CO = g;
endmodule
set_driving_cell BUF_X4
set_load 3.899
\ No newline at end of file
# Default constraints file that sets up clocks based on definitions in schema.
# Part 1 creates a clock (with uncertainty) for every datasheet pin of type
# "clock"; part 2 derives IO driving-cell / load constraints from a buffer
# cell (either configured via the sdc_buffer variable or the first buffer
# found in the loaded liberty libraries).
source sc_manifest.tcl > /dev/null
### Create clocks
if { [sc_cfg_exists datasheet pin] } {
set clock_idx 0
foreach pin [dict keys [sc_cfg_get datasheet pin]] {
if { [sc_cfg_get datasheet pin $pin type global] == "clock" } {
# If clock...
# tperiod/tjitter tuples are (min, typ, max); index 1 selects typical.
set periodtuple [sc_cfg_get datasheet pin $pin tperiod global]
set period [sta::time_sta_ui [lindex $periodtuple 1]]
set jittertuple [sc_cfg_get datasheet pin $pin tjitter global]
set jitter [sta::time_sta_ui [lindex $jittertuple 1]]
# Clocks are named clk0, clk1, ... in datasheet order.
set clk_name "clk${clock_idx}"
incr clock_idx
set period_fmt \
[sta::format_time [sta::time_ui_sta $period] 3][sta::unit_scale_abbreviation time]
set jitter_fmt \
[sta::format_time [sta::time_ui_sta $jitter] 3][sta::unit_scale_abbreviation time]
puts \
"Creating clock $clk_name with ${period_fmt}s period and ${jitter_fmt}s jitter."
create_clock -name $clk_name -period $period $pin
set_clock_uncertainty $jitter [get_clock $clk_name]
}
}
}
### Create IO constraints
set sc_sdc_buffer []
if { [sc_cfg_tool_task_exists {var} sdc_buffer] } {
set sc_sdc_buffer [sc_cfg_tool_task_get {var} sdc_buffer]
}
set buffer_cell "NULL"
if { [llength $sc_sdc_buffer] == 0 } {
# No buffer configured: scan the libraries for the first buffer cell.
foreach cell [get_lib_cells *] {
if { [get_property $cell is_buffer] } {
# Find first buffer and use that as IO constraints
set buffer_cell $cell
break
}
}
} else {
set buffer_cell [get_lib_cells [lindex $sc_sdc_buffer 0]]
}
if { $buffer_cell != "NULL" && $buffer_cell != "" } {
puts "Using [get_name $buffer_cell] as the SDC IO constraint cell"
set driving_port "NULL"
set load_cap 0.0
# Walk the buffer's liberty ports: the input pin capacitance sets the
# output load (x10), the output pin becomes the input driving pin.
set port_itr [$buffer_cell liberty_port_iterator]
while { [$port_itr has_next] } {
set port [$port_itr next]
set pcap [$port capacitance NULL max]
if { [get_property $port direction] == "input" } {
set load_cap [expr { 10 * $pcap }]
} elseif { [get_property $port direction] == "output" } {
set driving_port [get_name $port]
}
}
$port_itr finish
if { $load_cap > 0.0 } {
set cap_fmt [sta::format_capacitance $load_cap 3][sta::unit_scale_abbreviation capacitance]
puts "Setting output load constraint to ${cap_fmt}F."
set_load [sta::capacitance_sta_ui $load_cap] [all_outputs]
}
if { $driving_port != "NULL" } {
puts "Setting input driving pin constraint to [get_name $buffer_cell]/$driving_port."
set_driving_cell -lib_cell [$buffer_cell name] -pin $driving_port [all_inputs]
}
}
# Fast Yosys synthesis template.  {file_path}, {top_module} and {t_ech} are
# Python str.format placeholders substituted before the script is run; the
# netlist is written to {file_path}g (input path with a 'g' suffix).
# NOTE(review): the 'abc -constr ... \' and 'dfflegalize ... \' lines below
# put an inline '# comment' AFTER the trailing backslash.  In Tcl (and in
# Yosys line continuation) the backslash must be the last character of the
# line, so these comments likely break the continuation and feed the comment
# text to the command as arguments -- verify and move the comments if so.
set base_dir [file dirname [file dirname [info script]]]
# =============================================================================
# Read Design and Libraries
# =============================================================================
read_verilog -noblackbox -sv {file_path} # Read RTL source
chparam -list {top_module} # Set top module parameters
stat # Initial design statistics
read_liberty -setattr liberty_cell -lib {t_ech} # Load timing library
# =============================================================================
# Core Synthesis
# =============================================================================
synth -top {top_module} -flatten # Flatten hierarchy and synthesize
synth -top {top_module} -run fine:check # Refine and check design
# =============================================================================
# ABC Optimization and Mapping
# =============================================================================
abc -fast -liberty {t_ech} # Fast logic optimization
abc -constr [file join $base_dir files_needed/sc_abc.constraints] \ # Constrained ABC mapping
-liberty {t_ech} \
-dont_use OAI211_X1 \
-dont_use CLKBUF_X1 \
-dont_use CLKBUF_X2 \
-dont_use CLKBUF_X3
# =============================================================================
# Sequential Element Mapping
# =============================================================================
dfflibmap -liberty {t_ech} # Map flip-flops to library
dfflegalize -cell $_DFF_P_ 01 \ # Legalize sequential cells
-cell $_DFF_PN0_ 01 \
-cell $_DFF_PN1_ 01 \
-cell $_DFFSR_PNN_ 01 \
t:$_DFF* t:$_SDFF*
# =============================================================================
# Final Output Generation
# =============================================================================
stat -liberty {t_ech} # Generate timing/area report
write_verilog -noexpr -nohex -nodec {file_path}g # Write final netlist
clean_zerowidth # Clean zero-width signals
echo off # Suppress redundant output
#**************************************************************
# Initialization and Parameter Settings
#**************************************************************
# OpenSTA timing/power report template.  {top_module}, {t_ech}, {file_path}
# and {file_id} are Python str.format placeholders substituted before the
# script is run; it reads the synthesized netlist {file_path}g and writes
# reports under reports/{file_id}/.
set sc_design {top_module}
set sc_scenarios "worst"
set sc_delaymodel "nldm"
set sc_sdc "files_needed/sc_constraints.sdc"
#**************************************************************
# Load Technology Libraries
#**************************************************************
puts "Defining timing corners: $sc_scenarios"
define_corners $sc_scenarios
set sc_targetlibs "{t_ech}"
puts "Reading liberty file for worst (typical): $sc_targetlibs"
read_liberty -corner worst $sc_targetlibs
#**************************************************************
# Load Design Netlist
#**************************************************************
read_verilog {file_path}g
link_design $sc_design
#**************************************************************
# Load Timing Constraints (SDC)
#**************************************************************
puts "Reading SDC: $sc_sdc"
read_sdc $sc_sdc
#**************************************************************
# Timing Analysis and Report Generation
#**************************************************************
# Retain only the critical report generation commands, others are commented out
# Setup Time Analysis (Setup) - Commented out unnecessary reports
# report_checks -path_delay max -format full_clock_expanded > reports/setup.rpt
# report_worst_slack -max > reports/worst_slack.setup.rpt
# report_tns > reports/tns.setup.rpt
# Retain unconstrained path report
puts "SC_METRIC: report_checks -unconstrained"
report_checks -unconstrained > reports/{file_id}/unconstrained.rpt
#**************************************************************
# Power Analysis
#**************************************************************
puts "SC_METRIC: report_power"
report_power -corner worst > reports/{file_id}/power.rpt
#**************************************************************
# Other Statistics (Optional Retention)
#**************************************************************
puts "SC_METRIC: cellarea"
# sc_design_area may fail on unlinked/hierarchical designs; report instead
# of aborting the whole script.
if { [catch {sc_design_area} area] } {
puts "Error: Area calculation failed. Check design hierarchy."
} else {
puts "Total cell area: $area um²"
}
\ No newline at end of file
# Full Yosys synthesis template, stage 1: read design, establish hierarchy,
# lower processes, and do a first optimization/flatten pass.  {file_path},
# {top_module} and {t_ech} are Python str.format placeholders substituted
# before the script is run.
set base_dir [file dirname [file dirname [info script]]]
# =============================================================================
# Read source RTL and libraries
# =============================================================================
read_verilog -noblackbox -sv {file_path}
chparam -list {top_module}
stat
read_liberty -setattr liberty_cell -lib {t_ech}
# =============================================================================
# Set hierarchy and flattening
# =============================================================================
hierarchy -top {top_module}
tribuf
scratchpad -set flatten.separator /
# Run the early part of the generic synth flow (begin..fine), with the
# Kogge-Stone LCU implementation available as an extra techmap file.
synth -flatten -extra-map [file join $base_dir files_needed/lcu_kogge_stone.v] -top {top_module} -run begin:fine
hierarchy -check -top {top_module}
# =============================================================================
# Process lowering and cleaning
# =============================================================================
# Explicit proc_* sub-passes (equivalent to a full `proc` expansion) that
# convert behavioral always-blocks into muxes, latches and FFs.
proc
proc_clean
proc_rmdead
proc_prune
proc_init
proc_arst
proc_rom
proc_mux
proc_dlatch
proc_dff
proc_memwr
proc_clean
# =============================================================================
# Basic optimization and flatten
# =============================================================================
# -keepdc preserves don't-care conditions before the first flatten.
opt_expr -keepdc
flatten
opt_expr
opt_clean
check
# =============================================================================
# General optimizations
# =============================================================================
# First opt round deliberately avoids inferring enable/sync-reset FFs
# (-nodffe -nosdff) so later passes see plain DFFs.
opt -nodffe -nosdff
opt_expr
opt_merge -nomux
opt_muxtree
opt_reduce
opt_merge
opt_dff -nodffe -nosdff
opt_clean
opt_expr
# =============================================================================
# FSM extraction and mapping
# =============================================================================
# Detect finite state machines, optimize/recode their state encodings,
# then map them back to plain logic.
fsm
fsm_detect
fsm_extract
fsm_opt
opt_clean
fsm_opt
fsm_recode
fsm_info
fsm_map
opt
opt_expr
opt_merge -nomux
opt_muxtree
opt_reduce
opt_merge
opt_dff
opt_clean
opt_expr
# =============================================================================
# Arithmetic and peephole optimizations
# =============================================================================
# wreduce shrinks over-wide signals; alumacc/share extract and share
# arithmetic (ALU/macc) resources.
wreduce
peepopt
opt_clean
alumacc
share
opt
opt_expr
opt_merge -nomux
opt_muxtree
opt_reduce
opt_merge
opt_dff
opt_clean
opt_expr
# =============================================================================
# Memory optimizations
# =============================================================================
# Collect memories without mapping them yet (-nomap), then optimize
# ports, share access logic, and finally collect into $mem cells.
memory -nomap
opt_mem
opt_mem_priority
opt_mem_feedback
memory_bmux2rom
memory_dff
opt_clean
memory_share
opt_mem_widen
opt_clean
memory_collect
opt_clean
# =============================================================================
# Additional synthesis and full optimization
# =============================================================================
# Finish the generic synth flow (fine..check) with the Kogge-Stone LCU
# extra map, then run aggressive (-full) optimization rounds.
synth -flatten -extra-map [file join $base_dir files_needed/lcu_kogge_stone.v] -top {top_module} -run fine:check
opt -fast -full
opt_expr -full
opt_merge
opt_dff
opt_clean
memory_map
opt -full
opt_expr -full
opt_merge -nomux
opt_muxtree
opt_reduce -full
opt_merge
opt_share
opt_dff
opt_clean
opt_expr -full
# =============================================================================
# Techmapping to target cells
# =============================================================================
# +/techmap.v is Yosys' builtin generic techmap library.
techmap -map +/techmap.v -map [file join $base_dir files_needed/lcu_kogge_stone.v]
opt -fast
opt_expr
opt_merge
opt_dff
opt_clean
# =============================================================================
# ABC optimization
# =============================================================================
abc -fast
opt -fast
opt_expr
opt_merge
opt_dff
opt_clean
# =============================================================================
# Cleanups and formal removal
# =============================================================================
# Drop $print cells and formal verification constructs before mapping.
delete */t:$print
chformal -remove
hierarchy -top {top_module}
opt -purge
opt_expr
opt_merge -nomux
opt_muxtree
opt_reduce
opt_merge
opt_dff
opt_clean -purge
opt_expr
# =============================================================================
# Specialized techmapping and logic extraction
# =============================================================================
# Map tristate buffers, extracted full/half adders (extract_fa) and latches
# onto standard cells using the project techmap files.
techmap -map [file join $base_dir files_needed/cells_tristatebuf.v]
techmap
opt -fast -purge
opt_expr
opt_merge
opt_dff
opt_clean -purge
extract_fa
techmap -map [file join $base_dir files_needed/cells_adders.v]
techmap
opt -fast -purge
opt_expr
opt_merge
opt_dff
opt_clean -purge
techmap -map [file join $base_dir files_needed/cells_latch.v]
techmap
opt -fast -purge
opt_expr
opt_merge
opt_dff
opt_clean -purge
# =============================================================================
# DFF mapping and legalization
# =============================================================================
rename -wire
dfflibmap -dont_use OAI211_X1 -liberty {t_ech}
# Restrict FF types to those with library equivalents (posedge DFFs with
# optional async set/reset), covering both plain and scan FF patterns.
dfflegalize -cell $_DFF_P_ 01 -cell $_DFF_PN0_ 01 -cell $_DFF_PN1_ 01 -cell $_DFFSR_PNN_ 01 t:$_DFF* t:$_SDFF*
techmap
opt -purge
opt_expr
opt_merge -nomux
opt_muxtree
opt_reduce
opt_merge
opt_dff
opt_clean -purge
opt_expr
opt_muxtree
opt_reduce
opt_merge
opt_dff
opt_clean -purge
opt_expr
# =============================================================================
# Final ABC with constraints
# =============================================================================
abc -constr [file join $base_dir files_needed/sc_abc.constraints] -liberty {t_ech} -dont_use OAI211_X1 -dont_use CLKBUF_X1 -dont_use CLKBUF_X2 -dont_use CLKBUF_X3
# =============================================================================
# Post-mapping cleanup and output
# =============================================================================
clean -purge
setundef -zero
splitnets
clean -purge
# Tie remaining constant bits to LOGIC0_X1/LOGIC1_X1 cells and buffer
# direct wire connections so the netlist is purely structural.
hilomap -singleton -locell LOGIC0_X1 Z -hicell LOGIC1_X1 Z
insbuf -buf BUF_X1 A Z
clean -purge
echo off
# Netlist is written to the input path with a 'g' suffix ({file_path}g).
write_verilog -noexpr -nohex -nodec {file_path}g
clean_zerowidth
stat -liberty {t_ech}
# setup.py
"""Packaging script for the ``eda_tools`` helper package.

``eda_tools`` bundles the EDA/Verilog analysis utilities (code extraction,
equivalence checking, PPA evaluation) shared by the ``verl`` and
``sandbox-runtime`` environments.  Install with ``pip install -e .``.
"""
from pathlib import Path

from setuptools import find_packages, setup

# Third-party runtime dependencies of eda_tools.
install_requires = [
    "siliconcompiler",
    "networkx",
    "openai",
    "psutil",
]

# Read the long description from README.md next to this file when present.
# Falling back to "" keeps builds working in trees without a README (the
# previous `open(...).read()` also leaked the file handle and crashed when
# the file was missing).
_readme = Path(__file__).with_name("README.md")
long_description = _readme.read_text(encoding="utf-8") if _readme.is_file() else ""

setup(
    name="eda_tools",  # distribution/import name
    version="0.1.0",  # bump on subsequent releases (e.g. 0.1.1)
    packages=find_packages(),  # auto-discovers the eda_tools/ package tree
    author="Your Name",
    description="A set of EDA tools for Verilog analysis (including PPA)",
    long_description=long_description,
    long_description_content_type="text/markdown",
    install_requires=install_requires,
)
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -8,7 +8,7 @@ import argparse
import json
from pprint import pprint
from transformers import AutoTokenizer
from verl.utils.reward_score.codev_eval_toolkit.eval_codev import extract_verilog
from eda_tools.core import extract_verilog
def mk_prompt_r1_v1(question):
......
tools:
- class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool"
config:
sandbox_fusion_url: "http://0.0.0.0:8181/run_code"
sandbox_fusion_url: "http://127.0.0.1:8181/run_code"
num_workers: 32
enable_global_rate_limit: true
rate_limit: 32
......
......@@ -165,16 +165,24 @@ class SandboxFusionTool(BaseTool):
def execute_code(self, instance_id, code, timeout=30, language="python"):
logger.info(f"Execute code from SandboxFusionTool")
code = code.strip()
# preprocess markdown format
if code.startswith("```") and code.endswith("```"):
code = '\n'.join(code.split('\n')[1:-1])
result_status, metadata = _process_single_case(0, None, None, self.sandbox_fusion_url, code, timeout, language)
# we should always expect this since we don't have correct answer
# print('During code execution, metadata is', metadata)
if metadata["run_status"] == "Finished":
actual_output = metadata["stdout"] if metadata["stdout"] is not None else ""
actual_error = ("Error:\n" + metadata["stderr"]) if metadata["stderr"].strip() else ""
actual_error = ("Runtime Error:\n" + metadata["stderr"]) if metadata["stderr"].strip() else ""
logger.info(f"ID {instance_id} in sandbox fusion: actual_output is {actual_output}")
return actual_output + actual_error
elif metadata["compile_status"] == "Finished" and metadata["compile_stderr"]:
# todo: extract code from line number & original code
# print("Compile Error:\n" + metadata["compile_stderr"].strip())
return "Compile Error:\n" + metadata["compile_stderr"].strip()
else:
return "no stdout here"
return "No stdout here"
async def calc_reward(self, instance_id: str, **kwargs) -> str:
return self._instance_dict[instance_id]["reward"]
......
......@@ -21,9 +21,8 @@ import hydra
import numpy as np
import ray
os.environ["NCCL_DEBUG"] = "WARN"
print('NCCL_DEBUG is', os.environ['NCCL_DEBUG'])
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# os.environ['TORCH_COMPILE_DISABLE'] = '1'
from pprint import pprint
......@@ -41,6 +40,28 @@ from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.trainer.main_ppo import create_rl_dataset
from torchdata.stateful_dataloader import StatefulDataLoader
# newly added, increase timeout
# import torch
# import torch.distributed as dist
# import traceback
# import datetime
# from verl.utils.device import get_device_id, get_device_name, get_nccl_backend
# print('torch.distributed.is_initialized is', torch.distributed.is_initialized())
# if not torch.distributed.is_initialized():
# rank = int(os.environ.get("RANK", 0))
# world_size = int(os.environ.get("WORLD_SIZE", 1))
# torch.distributed.init_process_group(backend=get_nccl_backend(), rank=rank, world_size=world_size)
# # dist.init_process_group(backend='gloo', timeout=datetime.timedelta(minutes=60))
# print("后端:", dist.get_backend())
# print("总进程数:", dist.get_world_size())
# print("当前进程 ID:", dist.get_rank())
# # print("当前组 ID:", dist.get_group_rank(dist.group.WORLD))
# # print("节点 ID:", dist.get_local_rank()) # 单节点内的进程序号(0,1,2...)
# print("是否为主进程:", dist.get_rank() == 0)
# print("超时时间:", dist.default_pg_timeout)
@hydra.main(config_path="config", config_name="generation", version_base=None)
def main(config):
......@@ -75,14 +96,18 @@ def main_task(config):
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# print('torch.distributed.is_initialized is', dist.is_initialized())
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
# print('torch.distributed.is_initialized is', dist.is_initialized())
wg = RayWorkerGroup(
resource_pool=resource_pool,
ray_cls_with_init=ray_cls_with_init,
device_name=config.trainer.device,
)
# print('torch.distributed.is_initialized is', dist.is_initialized())
wg.init_model()
# print('torch.distributed.is_initialized is', dist.is_initialized())
# read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
# dataset = pd.read_parquet(config.data.path)
......
import re
from verl.utils.reward_score.codev_eval_toolkit.eval_codev import verify_one_sample, verify_one_sample_wrapper, extract_verilog
def compute_score_617795(solution_str, ground_truth, exceed_length=False):
    """Reward function (dataset 617795) for Verilog generation rollouts.

    Scores only the assistant turn of ``solution_str``:
      * +1.0  well-formatted output whose extracted Verilog verifies
              against ``ground_truth``;
      * -0.5  well-formatted but incorrect answer;
      *  0.0  output truncated by the length limit (``exceed_length``)
              whose partial tag format is still consistent;
      * -1.0  malformed output or no extractable Verilog.
    """
    # Keep only the assistant's response when the chat delimiter is present.
    response_pos = solution_str.find("<|im_start|>assistant")
    if response_pos >= 0:
        solution_str = solution_str[response_pos:]

    extracted_answer = extract_verilog(solution_str)

    def check_format(output):
        # All four tags must be present, in this exact order.
        tags = ["<think>", "</think>", "<answer>", "</answer>"]
        positions = [output.find(tag) for tag in tags]
        return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]

    def check_partial_format(output):
        # Like check_format, but tolerant of a truncated tail: each missing
        # tag is treated as if it appeared right after the previous one.
        tags = ["<think>", "</think>", "<answer>", "</answer>"]
        positions = [output.find(tag) for tag in tags]
        for i in range(1, 4):
            if positions[i] == -1 and positions[i - 1] >= 0:
                positions[i] = positions[i - 1] + 1
        return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]

    if not check_format(solution_str) or extracted_answer is None:
        # Forgive length-limited truncation if the partial format is
        # consistent; penalize everything else.
        reward = 0.0 if exceed_length and check_partial_format(solution_str) else -1.0
    else:
        result = verify_one_sample_wrapper((ground_truth, extracted_answer))
        reward = 1.0 if result["correct"] else -0.5
    return reward
def compute_score_618832(solution_str, ground_truth, exceed_length=False):
    """Reward function (dataset 618832) for Verilog generation rollouts.

    Identical scoring scheme to ``compute_score_617795`` except that a
    verified-correct answer earns +3.0 instead of +1.0:
      * +3.0  well-formatted output whose extracted Verilog verifies
              against ``ground_truth``;
      * -0.5  well-formatted but incorrect answer;
      *  0.0  output truncated by the length limit (``exceed_length``)
              whose partial tag format is still consistent;
      * -1.0  malformed output or no extractable Verilog.
    """
    # Keep only the assistant's response when the chat delimiter is present.
    response_pos = solution_str.find("<|im_start|>assistant")
    if response_pos >= 0:
        solution_str = solution_str[response_pos:]

    extracted_answer = extract_verilog(solution_str)

    def check_format(output):
        # All four tags must be present, in this exact order.
        tags = ["<think>", "</think>", "<answer>", "</answer>"]
        positions = [output.find(tag) for tag in tags]
        return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]

    def check_partial_format(output):
        # Like check_format, but tolerant of a truncated tail: each missing
        # tag is treated as if it appeared right after the previous one.
        tags = ["<think>", "</think>", "<answer>", "</answer>"]
        positions = [output.find(tag) for tag in tags]
        for i in range(1, 4):
            if positions[i] == -1 and positions[i - 1] >= 0:
                positions[i] = positions[i - 1] + 1
        return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]

    if not check_format(solution_str) or extracted_answer is None:
        # Forgive length-limited truncation if the partial format is
        # consistent; penalize everything else.
        reward = 0.0 if exceed_length and check_partial_format(solution_str) else -1.0
    else:
        result = verify_one_sample_wrapper((ground_truth, extracted_answer))
        reward = 3.0 if result["correct"] else -0.5
    return reward
from eda_tools.core import verify_one_sample, ppa_one_sample, run_function_with_timeout, extract_verilog
def compute_score(solution_str, ground_truth, **kwargs):
......@@ -112,7 +40,9 @@ def compute_score(solution_str, ground_truth, **kwargs):
if not check_format(solution_str) or extracted_answer is None:
reward = 0.0
else:
result = verify_one_sample_wrapper((ground_truth, extracted_answer))
result = run_function_with_timeout(verify_one_sample, ground_truth, extracted_answer)
# 加上PPA的result
#######################################################################################
# print("result is", result)
if result["correct"] == True:
reward = 1.0
......@@ -148,7 +78,8 @@ def compute_score_wrapper(data_source, solution_str, ground_truth, extra_info, *
if __name__ == '__main__':
file = "/nfs_global/S/zhuyaoyu/projects/CodeV-o1/data/source/codev_dataset_165k_wo_module_head.jsonl"
# file = "/nfs_global/S/zhuyaoyu/projects/CodeV-o1/data/source/codev_dataset_165k_wo_module_head.jsonl"
file = "/nfs_global/datasets/codev/codev_dataset_165k_v3.jsonl"
import json
with open(file, "r") as f:
data = list(map(json.loads, f.read().strip().splitlines()))
......@@ -261,4 +192,11 @@ endmodule"""
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> <answer>```verilog\nmodule my374_labl;\n reg temp;\nendmodule\n```</answer>"
example_ans = "module my374 lab1 ();\n\treg temp;\nendmodule"
\ No newline at end of file
example_ans = "module my374 lab1 ();\n\treg temp;\nendmodule"
# 正确
example_ans = data[2]["response"]
example_output = f"<think></think> <answer>\n```verilog\n{example_ans}```\n</answer>"
reward = compute_score(example_output, example_ans)
ppa = ppa_one_sample(example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}\n{sep}\n{ppa}")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment