Commit 3f1e287d by root

add example for verilog tool call, might need tuning the prompt

parent c765c731
...@@ -11,13 +11,13 @@ ...@@ -11,13 +11,13 @@
### 配置 ### 配置
TIR的主要配置在`examples/tir/run_sandbox_fusion.sh``examples/tir/sandbox_fusion_tool_config.yaml`里面。 TIR的主要配置在`examples/tir/sandbox_fusion_python_config.yaml`里面。
其中,由于是本地起的沙盒服务,沙盒IP取决于被分配到了哪个计算节点,因此我套了一层`examples/tir/sandbox_fusion_tool_config.yaml.template`,在执行脚本里面用 其中,由于是本地起的沙盒服务,沙盒IP取决于被分配到了哪个计算节点,因此我套了一层`examples/tir/sandbox_fusion_python_config.yaml.template`,在执行脚本里面用
```bash ```bash
TOOL_CONFIG_PATH=$CURR_DIR/examples/tir/sandbox_fusion_tool_config.yaml TOOL_CONFIG_PATH=$CURR_DIR/examples/tir/sandbox_fusion_python_config.yaml
envsubst < "$TOOL_CONFIG_PATH.template" > $TOOL_CONFIG_PATH envsubst < "$TOOL_CONFIG_PATH.template" > $TOOL_CONFIG_PATH
``` ```
生成实际的`examples/tir/sandbox_fusion_tool_config.yaml`这个配置文件的`tool_schema`部分会给加到system prompt里面去。 生成实际的`examples/tir/$TOOL_CONFIG_PATH.yaml`。在parquet数据预处理的时候每条加上`{"code_interpreter": {"create_kwargs": {'dummy': None}}}`(具体见`examples/data_preprocess/convert_eurus_tir.py`)后,这个配置文件的`tool_schema`部分会给加到system prompt里面去。
具体执行tool的代码应该在`verl/tools/sandbox_fusion_tools.py``async def execute`那边。 具体执行tool的代码应该在`verl/tools/sandbox_fusion_tools.py``async def execute`那边。
......
...@@ -21,7 +21,7 @@ def mk_prompt_r1_v1(question): ...@@ -21,7 +21,7 @@ def mk_prompt_r1_v1(question):
if pos >= 0: if pos >= 0:
question = question[pos + len('Now, try to write the corresponding verilog code based on the following content through the above guidelines:\n'):] question = question[pos + len('Now, try to write the corresponding verilog code based on the following content through the above guidelines:\n'):]
system_prompt = """You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...) ... ``` </answer>.\n""" system_prompt = """You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...); ... ``` </answer>.\n"""
user_prompt = question.strip() + "\n" user_prompt = question.strip() + "\n"
conversation = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] conversation = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
return conversation return conversation
...@@ -34,49 +34,20 @@ def mk_prompt_r1_v1_1(question): ...@@ -34,49 +34,20 @@ def mk_prompt_r1_v1_1(question):
if pos >= 0: if pos >= 0:
question = question[pos + len('Now, try to write the corresponding verilog code based on the following content through the above guidelines:\n'):] question = question[pos + len('Now, try to write the corresponding verilog code based on the following content through the above guidelines:\n'):]
system_prompt = """You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...) ... ``` </answer>.\n""" system_prompt = """You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...); ... ``` </answer>.\n"""
user_prompt = question.strip() + "\n" user_prompt = question.strip() + "\n"
conversation = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] conversation = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
return conversation return conversation
def mk_prompt_r1_v2(question):
# 格式要改
question = question.replace("Enclose your code with [BEGIN] and [DONE]. Only output the code snippet\nand do NOT output anything else.\n\n", "")
prompt = f"""<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process.\n\nPlease structure your response into two main sections: Thought and Solution.\n\nIn the Thought section, detail your reasoning process using the specified format:\n```\n<think>\n{{thought with steps separated with \"\n\n\"}}\n</think>\n```\nEach step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps.\n\nIn the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows:\n```\n<answer>\n{{final formatted, precise, and clear solution}}\n</answer>\n```\nNow, try to write the corresponding verilog code based on the following content through the above guidelines:
{question}
<|im_end|>
<|im_start|>assistant
<think>"""
conversation = [{"role": "user", "content": prompt}]
return conversation
def mk_prompt_r1_v3(question):
question = question.replace("Enclose your code with [BEGIN] and [DONE]. Only output the code snippet\nand do NOT output anything else.\n\n", "")
pos = question.find("The module head of the code should be:")
if pos >= 0:
question = question[:pos]
prompt = f"""<|im_start|>system\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.\n<|im_end|>\n<|im_start|>user\nNow, try to write the corresponding verilog code based on the following content:\n{question}\n<|im_end|>\n<|im_start|>assistant\n<think>"""
conversation = [{"role": "user", "content": prompt}]
return conversation
def mk_prompt_r1(question): def mk_prompt_r1(question):
question = question.replace("Enclose your code with [BEGIN] and [DONE]. Only output the code snippet\nand do NOT output anything else.\n\n", "") question = question.replace("Enclose your code with [BEGIN] and [DONE]. Only output the code snippet\nand do NOT output anything else.\n\n", "")
# pos = question.find("The module head of the code should be:") # pos = question.find("The module head of the code should be:")
# if pos >= 0: # if pos >= 0:
# question = question[:pos] # question = question[:pos]
system_prompt = """You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...) ... ``` </answer>.\n""" system_prompt = """You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...); ... ``` </answer>.\n"""
user_prompt = question.strip() + "\n" user_prompt = question.strip() + "\n"
# prompt = f"""<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...) ... ``` </answer>.\n<|im_end|>\n<|im_start|>user\n{question}\n<|im_end|>\n<|im_start|>assistant\n<think>""" # prompt = f"""<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog\n module top_module(in, out, ...); ... ``` </answer>.\n<|im_end|>\n<|im_start|>user\n{question}\n<|im_end|>\n<|im_start|>assistant\n<think>"""
conversation = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] conversation = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
return conversation return conversation
...@@ -85,13 +56,13 @@ if __name__ == '__main__': ...@@ -85,13 +56,13 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='data/codev/v1/3.1k_r1_filtered') parser.add_argument('--local_dir', default='data/codev/v1/3.1k_r1_filtered')
parser.add_argument('--hdfs_dir', default=None) parser.add_argument('--hdfs_dir', default=None)
# parser.add_argument('--data_path', default='/nfs_global/S/lvhanqi/codev_data/decontamination_sft_model_filter_4.8k_and_qwen_32b_correct1234_system_prompt_codev_dataset_v3.jsonl')
parser.add_argument('--data_path', default='/nfs_global/S/lvhanqi/codev_data/decontamination_sft_model_filter_0320_correct_synthesizable_r1_system_prompt_filter_codev_dataset_165k_v3.jsonl') parser.add_argument('--data_path', default='/nfs_global/S/lvhanqi/codev_data/decontamination_sft_model_filter_0320_correct_synthesizable_r1_system_prompt_filter_codev_dataset_165k_v3.jsonl')
parser.add_argument('--tokenizer_path', default='/share/collab/codemodel/models/Qwen2.5-Coder-7B') parser.add_argument('--tokenizer_path', default='/share/collab/codemodel/models/Qwen2.5-Coder-7B')
parser.add_argument('--train_size', type=int, default=15000) parser.add_argument('--train_size', type=int, default=15000)
parser.add_argument('--test_size', type=int, default=984) parser.add_argument('--test_size', type=int, default=984)
parser.add_argument('--save_jsonl', action='store_true', help='Save dataset as jsonl files') parser.add_argument('--save_jsonl', action='store_true', help='Save dataset as jsonl files')
parser.add_argument('--gt', type=str, default=['gt'], choices=['gt', 'r1', 'double'], help='Choose ground_truth or r1 response or both as ground truth') parser.add_argument('--gt', type=str, default=['gt'], choices=['gt', 'r1', 'double'], help='Choose ground_truth or r1 response or both as ground truth')
parser.add_argument('--tool', action='store_true', help='Whether to prompt LLMs to use tools')
# continuous_reward is moved to training cfg # continuous_reward is moved to training cfg
# parser.add_argument('--continuous_reward', action='store_true', help='Save dataset as jsonl files') # parser.add_argument('--continuous_reward', action='store_true', help='Save dataset as jsonl files')
# parser.add_argument('--template_type', type=str, default='base') # parser.add_argument('--template_type', type=str, default='base')
...@@ -142,15 +113,16 @@ if __name__ == '__main__': ...@@ -142,15 +113,16 @@ if __name__ == '__main__':
else: else:
ground_truth = example["response"] ground_truth = example["response"]
question = make_question(example["question"]) question = make_question(example["question"])
# if args.continuous_reward:
# ground_truth = {"answer": ground_truth, "reward_mode": "continuous"}
if args.gt == 'both': if args.gt == 'both':
ground_truth = {"answer": ground_truth, "r1_answer": extract_verilog(example["r1_response"]["content"])} ground_truth = {"answer": ground_truth, "r1_answer": extract_verilog(example["r1_response"]["content"])}
elif args.gt == 'r1': elif args.gt == 'r1':
ground_truth = extract_verilog(example["r1_response"]["content"]) ground_truth = extract_verilog(example["r1_response"]["content"])
# pprint(ground_truth)
# exit(0) extra_info = {"split": split, "index": idx}
if args.tool:
extra_info["tools_kwargs"] = {"code_interpreter": {"create_kwargs": {'dummy': None}}}
data = { data = {
"data_source": "codev", "data_source": "codev",
"prompt": question, "prompt": question,
...@@ -159,10 +131,7 @@ if __name__ == '__main__': ...@@ -159,10 +131,7 @@ if __name__ == '__main__':
"style": "rule", "style": "rule",
"ground_truth": ground_truth "ground_truth": ground_truth
}, },
"extra_info": { "extra_info": extra_info
'split': split,
'index': idx,
}
} }
return data return data
return process_fn return process_fn
......
tools:
- class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool"
config:
sandbox_fusion_url: "http://10.21.0.12:8181/run_code"
num_workers: 32
enable_global_rate_limit: true
rate_limit: 32
default_timeout: 30
default_language: "verilog"
tool_schema:
type: "function"
function:
name: "code_interpreter"
description: "A code execution tool."
parameters:
type: "object"
properties:
code:
type: "string"
description: "The verilog code to execute."
required: ["code"]
\ No newline at end of file
tools:
- class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool"
config:
sandbox_fusion_url: "$SANDBOX_URL"
num_workers: 32
enable_global_rate_limit: true
rate_limit: 32
default_timeout: 30
default_language: "verilog"
tool_schema:
type: "function"
function:
name: "code_interpreter"
description: "A code execution tool."
parameters:
type: "object"
properties:
code:
type: "string"
description: "The verilog code to execute."
required: ["code"]
\ No newline at end of file
set -x
export VERL_LOGGING_LEVEL=INFO
python3 -X faulthandler -u -m verl.trainer.main_ppo \
reward_model.sandbox_fusion.url=$SANDBOX_URL \
reward_model.sandbox_fusion.max_concurrent=128 \
reward_model.reward_manager=prime \
algorithm.adv_estimator=grpo \
data.train_files=$CURR_DIR/data/codev/v1/16k_r1_filtered/train.parquet \
data.val_files=$CURR_DIR/data/codev/v1/16k_r1_filtered/test.parquet \
data.train_batch_size=64 \
data.max_prompt_length=2048 \
data.max_response_length=4096 \
data.return_raw_chat=True \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=/nfs_global/models/Qwen2.5-Coder-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.use_liger=True \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_turns=3 \
actor_rollout_ref.rollout.multi_turn.tool_config_path=$TOOL_CFG_PATH \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=16 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example_sandbox_fusion' \
trainer.experiment_name='codev_sandbox_fusion' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=1 \
trainer.total_epochs=15 $@
# actor_rollout_ref.rollout.multi_turn.max_turns 后面改成 max_assistant_turns 了
\ No newline at end of file
...@@ -30,7 +30,7 @@ python3 -X faulthandler -u -m verl.trainer.main_ppo \ ...@@ -30,7 +30,7 @@ python3 -X faulthandler -u -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.name=sglang \ actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.multi_turn.enable=True \ actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_turns=3 \ actor_rollout_ref.rollout.multi_turn.max_turns=3 \
actor_rollout_ref.rollout.multi_turn.tool_config_path=$CURR_DIR/examples/tir/sandbox_fusion_tool_config.yaml \ actor_rollout_ref.rollout.multi_turn.tool_config_path=$TOOL_CFG_PATH \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=16 \ actor_rollout_ref.rollout.n=16 \
algorithm.use_kl_in_reward=False \ algorithm.use_kl_in_reward=False \
......
...@@ -10,4 +10,7 @@ ...@@ -10,4 +10,7 @@
# python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_filter_error_rate_l_0.2_from_87k.jsonl --local_dir data/codev/v1/err_l0.2_16k_r1_filtered --train_size 16364 --test_size 300 # python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_filter_error_rate_l_0.2_from_87k.jsonl --local_dir data/codev/v1/err_l0.2_16k_r1_filtered --train_size 16364 --test_size 300
# python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_filter_error_rate_l_0.2_from_87k.jsonl --local_dir data/codev/v1/err_l0.2_16k_r1_filtered_double_gt --gt double --train_size 16364 --test_size 300 # python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_filter_error_rate_l_0.2_from_87k.jsonl --local_dir data/codev/v1/err_l0.2_16k_r1_filtered_double_gt --gt double --train_size 16364 --test_size 300
# python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_qwen7b32b_filter_gt_r1_error_rate_e_0.5_from_87k.jsonl --local_dir data/codev/v1/qwen7b32b_filter_gt_r1_error_rate_e_0.5_7.4k --gt r1 --train_size 7204 --test_size 200 # python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_qwen7b32b_filter_gt_r1_error_rate_e_0.5_from_87k.jsonl --local_dir data/codev/v1/qwen7b32b_filter_gt_r1_error_rate_e_0.5_7.4k --gt r1 --train_size 7204 --test_size 200
python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_87k_correct1234_filter_qwen7b32b_data.jsonl --local_dir data/codev/v1/qwen7b32b_filter_gt_r1_14k --gt r1 --train_size 14654 --test_size 300 # python examples/data_preprocess/codev.py --data_path /nfs_global/S/lvhanqi/codev_data/sft_model_87k_correct1234_filter_qwen7b32b_data.jsonl --local_dir data/codev/v1/qwen7b32b_filter_gt_r1_14k --gt r1 --train_size 14654 --test_size 300
\ No newline at end of file
# tencent cloud
python examples/data_preprocess/codev.py --tool --tokenizer_path /nfs_global/models/Qwen2.5-Coder-7B-Instruct/ --data_path /nfs_global/datasets/codev/codev_r1_rl_16k.jsonl --local_dir data/codev/v1/16k_r1_filtered --train_size 15691 --test_size 300
\ No newline at end of file
...@@ -87,6 +87,13 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No ...@@ -87,6 +87,13 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No
from . import search_r1_like_qa_em from . import search_r1_like_qa_em
res = search_r1_like_qa_em.compute_score(solution_str, ground_truth) res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
elif data_source in ['codev']:
# if sandbox_fusion_url:
# from . import sandbox_fusion
# res = sandbox_fusion.compute_score(sandbox_fusion_url, concurrent_semaphore, solution_str, ground_truth, continuous=True)
# else:
from . import codev
res = codev.compute_score(solution_str, ground_truth)
else: else:
raise NotImplementedError(f"Reward function is not implemented for {data_source=}") raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
except Exception as e: except Exception as e:
......
import re
from verl.utils.reward_score.codev_eval_toolkit.eval_codev import verify_one_sample, verify_one_sample_wrapper, extract_verilog
def compute_score_617795(solution_str, ground_truth, exceed_length=False):
response_pos = solution_str.find("<|im_start|>assistant")
if response_pos >= 0:
solution_str = solution_str[response_pos:]
else:
pass
extracted_answer = extract_verilog(solution_str)
def check_format(output):
tags = ["<think>", "</think>", "<answer>", "</answer>"]
positions = [output.find(tag) for tag in tags]
return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]
def check_partial_format(output):
tags = ["<think>", "</think>", "<answer>", "</answer>"]
positions = [output.find(tag) for tag in tags]
for i in range(1, 4):
if positions[i] == -1 and positions[i-1] >= 0:
positions[i] = positions[i-1] + 1
return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]
if not check_format(solution_str) or extracted_answer is None:
if exceed_length and check_partial_format(solution_str):
reward = 0.0
else:
reward = -1.0
else:
result = verify_one_sample_wrapper((ground_truth, extracted_answer))
if result["correct"] == True:
reward = 1.0
else:
reward = -0.5
return reward
def compute_score_618832(solution_str, ground_truth, exceed_length=False):
response_pos = solution_str.find("<|im_start|>assistant")
if response_pos >= 0:
solution_str = solution_str[response_pos:]
else:
pass
extracted_answer = extract_verilog(solution_str)
def check_format(output):
tags = ["<think>", "</think>", "<answer>", "</answer>"]
positions = [output.find(tag) for tag in tags]
return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]
def check_partial_format(output):
tags = ["<think>", "</think>", "<answer>", "</answer>"]
positions = [output.find(tag) for tag in tags]
for i in range(1, 4):
if positions[i] == -1 and positions[i-1] >= 0:
positions[i] = positions[i-1] + 1
return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]
if not check_format(solution_str) or extracted_answer is None:
if exceed_length and check_partial_format(solution_str):
reward = 0.0
else:
reward = -1.0
else:
result = verify_one_sample_wrapper((ground_truth, extracted_answer))
if result["correct"] == True:
reward = 3.0
else:
reward = -0.5
return reward
def compute_score(solution_str, ground_truth, **kwargs):
reward_mode = kwargs.get('reward_mode', 'discrete')
assert reward_mode in ['discrete', 'continuous'], "mode should be either 'discrete' or 'continuous'"
# print("Reward mode is:", reward_mode)
if reward_mode == 'continuous':
err_threshold = kwargs.get('err_threshold', None)
reward_mapping = kwargs.get('reward_mapping', 'threshold')
assert reward_mode != 'continuous' or err_threshold is not None, "err_threshold should be given when using continuous reward!"
assert reward_mapping in ['threshold', 'zero'], "reward_mapping should be either 'threshold' or 'zero'"
gt_keys = kwargs.get('gt_keys', None)
if gt_keys is not None:
assert isinstance(ground_truth, dict), "ground_truth should be a dict when gt_keys is given"
gts = [ground_truth[key] for key in gt_keys]
else:
assert isinstance(ground_truth, str), "ground_truth should be a string when gt_keys is not given"
gts = [ground_truth]
# model_output= re.sub(r'^.*?<\|im_start\|>assistant', '<|im_start|>assistant', model_output, flags=re.DOTALL,count = 1)
response_pos = solution_str.find("<|im_start|>assistant")
if response_pos >= 0:
solution_str = solution_str[response_pos:]
else:
# 这样应该是题目过长?但题目长度被卡了,应该没影响
pass
def check_format(output):
tags = ["<think>", "</think>", "<answer>", "</answer>"]
tag_count = [output.count(tag) for tag in tags]
positions = [output.find(tag) for tag in tags]
return min(tag_count) == max(tag_count) == 1 and positions[0] < positions[1] < positions[2] < positions[3]
def calc_reward(solution_str, ground_truth):
extracted_answer = extract_verilog(solution_str)
if not check_format(solution_str) or extracted_answer is None:
reward = 0.0
else:
result = verify_one_sample_wrapper((ground_truth, extracted_answer))
# print("result is", result)
if result["correct"] == True:
reward = 1.0
else:
if reward_mode == 'discrete':
reward = 0.0
else:
# GRPO对比单题用error_rate应该问题不大,别的算法不好说
if 'error_rate' in result and result['error_rate'] <= err_threshold:
reward = 1 - result['error_rate'] if reward_mapping == 'threshold' else err_threshold - result['error_rate']
# print('Error rate is', result['error_rate'])
else:
reward = 0.0
# if "test_error" in result:
# print("=============test error=============")
# print(result["test_error"])
# print("=============extracted_answer=============")
# print(extracted_answer)
# print("=============ground_truth=============")
# print(ground_truth)
# print(reward_mode)
# print("Reward is", reward)
return reward
rewards = [calc_reward(solution_str, gt) for gt in gts]
reward = max(rewards)
return reward
def compute_score_wrapper(data_source, solution_str, ground_truth, extra_info, **kwargs):
return compute_score(solution_str, ground_truth, **kwargs)
if __name__ == '__main__':
file = "/nfs_global/S/zhuyaoyu/projects/CodeV-o1/data/source/codev_dataset_165k_wo_module_head.jsonl"
import json
with open(file, "r") as f:
data = list(map(json.loads, f.read().strip().splitlines()))
sep = "============================================"
# 正确
example_ans = data[0]["response"]
example_output = f"<think></think> <answer>\n```verilog\n{example_ans}```\n</answer>"
reward = compute_score(example_output, example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
# 格式错误
example_output = f"<think> <answer></think> ```verilog\n{example_ans}```</answer>"
reward = compute_score(example_output, example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
# 答案错误
example_output = f"<think> </think> <answer>\n```verilog\n```\n</answer>"
reward = compute_score(example_output, example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
example_output = """
<|im_start|>system
You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog
module top_module(in, out, ...) ... ``` </answer>.
<|im_end|>
<|im_start|>user
Design a signed adder that adds two 8-bit 2's complement numbers, a[7:0] and b[7:0]. These numbers are added to produce s[7:0]. Additionally, compute whether a signed overflow has occurred during the addition. The overflow condition should be detected when both operands have the same sign but the result differs in sign. This Verilog module, named `signed_adder`, has the interface designed as follows:
| Signal Name | Direction | Width | Description |
|-------------|-----------|-------|-----------------------------------------------|
| `a` | Input | 8 | First 8-bit signed 2's complement number |
| `b` | Input | 8 | Second 8-bit signed 2's complement number |
| `s` | Output | 8 | 8-bit signed result of the addition |
| `overflow` | Output | 1 | High if a signed overflow has occurred |
<|im_end|>
<|im_start|>assistant
<think>
To design a signed adder that adds two 8-bit 2's complement numbers and detects overflow, we need to follow these steps:
1. **Add the two 8-bit numbers**: Use a simple adder to get the sum.
2. **Detect overflow**: Check if both operands have the same sign but the result has a different sign. This indicates an overflow.
The overflow condition can be detected using the following logic:
- If both `a[7]` and `b[7]` are 0 (both are positive) and `s[7]` is 1 (result is negative), an overflow has occurred.
- If both `a[7]` and `b[7]` are 1 (both are negative) and `s[7]` is 0 (result is positive), an overflow has occurred.
Let's write the Verilog code based on this logic.
</think>
<answer> ```verilog
module signed_adder (
input [7:0] a,
input [7:0] b,
output reg [7:0] s,
output reg overflow
);
always @(*) begin
// Perform the addition
s = a + b;
// Determine overflow
overflow = (a[7] == b[7]) & (a[7] != s[7]);
end
endmodule
``` </answer><|im_end|>"""
example_ans = """
module signed_adder(input [7:0] a,
input [7:0] b,
output reg [7:0] s,
output reg overflow
);
// Perform the addition
assign s = a + b;
// Detect overflow
assign overflow = (a[7] == b[7]) && (a[7] != s[7]);
endmodule"""
# 正确
reward = compute_score(example_output, example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
# with open("/nfs_global/S/zhuyaoyu/projects/logicRL/outputs/all_wrong_problems.jsonl", "r") as f:
# data = list(map(json.loads, f.read().strip().splitlines()))
# for item in data[::32]:
# response, gt = item['responses'], item['ground_truth']
# reward = compute_score(response, gt)
# # print(f"{sep}\n{response}\n{sep}\n{reward}")
# print(reward)
# # print(item['problem_id'])
# # print(response)
# 格式错误
example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> <answer><think></think>\n```verilog\n```\n</answer>"
reward = compute_score(example_output, example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
# 格式错误
example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> </think> <answer>\n```verilog\n```\n</answer>"
reward = compute_score(example_output, example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
# 答案错误
example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> <answer>\n```verilog\n```\n</answer>"
reward = compute_score(example_output, example_ans)
print(f"{sep}\n{example_output}\n{sep}\n{reward}")
example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> <answer>```verilog\nmodule my374_labl;\n reg temp;\nendmodule\n```</answer>"
example_ans = "module my374 lab1 ();\n\treg temp;\nendmodule"
\ No newline at end of file
from verl.utils.reward_score.codev_eval_toolkit.verify import eda_tools
import json
import re
import os
from tqdm.contrib.concurrent import process_map
from multiprocessing import Process, Queue
import psutil
import hashlib
import random
import platform
# # 根据不同系统导入不同的文件锁模块
# if platform.system() == 'Windows':
# import msvcrt
# else:
# import fcntl
# # 假设的锁文件路径
# LOCK_FILE_PATH = '.lock'
# def create_lock_file():
# if not os.path.exists(LOCK_FILE_PATH):
# with open(LOCK_FILE_PATH, 'w') as f:
# pass
# def acquire_lock():
# if platform.system() == 'Windows':
# f = open(LOCK_FILE_PATH, 'r+')
# msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1)
# return f
# else:
# f = open(LOCK_FILE_PATH, 'r+')
# fcntl.flock(f.fileno(), fcntl.LOCK_EX)
# return f
# def release_lock(f):
# if platform.system() == 'Windows':
# msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
# else:
# fcntl.flock(f.fileno(), fcntl.LOCK_UN)
# f.close()
def verify_one_sample(gold_code, dut_code, uid=None):
uid = dut_code + str(random.randint(0,2147483647))
uid = hashlib.md5(uid.encode("utf-8")).hexdigest()
v = eda_tools(quiet=True)
# v = eda_tools(quiet=False)
if not gold_code or not dut_code:
return {"correct": False}
try:
gold_top = v.auto_top(gold_code)
gate_top = v.auto_top(dut_code)
except Exception as e:
# exception in verification, gold code or dut code have syntax problems
# print("Parse error:", e.args)
return {"correct": False, "parse_error": e.args}
gold_path, dut_path = f"./tmp/testcase/{uid}_gold.v", f"./tmp/testcase/{uid}_dut.v"
test_path = f"./tmp/work/{uid}"
try:
if not os.path.exists("./tmp/testcase"):
os.makedirs("./tmp/testcase", exist_ok=True)
if not os.path.exists("./tmp/work"):
os.makedirs("./tmp/work", exist_ok=True)
if not os.path.exists(test_path):
os.makedirs(test_path, exist_ok=True)
finally:
# release_lock(f)
pass
with open(gold_path, "w") as f:
f.write(gold_code)
with open(dut_path, "w") as f:
f.write(dut_code)
# 如果想生成testbench代码并运行,参考以下内容
result = None
try:
equiv = v.equiv_with_testbench(
gold_path,
dut_path,
gold_top,
gate_top,
test_path,
)
except Exception as e:
# print("Test error:", e.args)
result = {"correct": False, "test_error": e.args}
finally:
if os.path.exists(gold_path):
os.remove(gold_path)
if os.path.exists(dut_path):
os.remove(dut_path)
if os.path.exists(test_path):
os.system(f"rm -r {test_path}")
if result is None:
result = {"correct": equiv[0], "error_rate": equiv[1], "detail": equiv[2]}
return result
def kill_process_tree(pid):
parent = psutil.Process(pid)
children = parent.children(recursive=True) # 获取所有子进程
for child in children:
child.terminate() # 终止子进程
parent.terminate() # 终止父进程
def verify_one_sample_wrapper(args):
def target(queue):
result = verify_one_sample(*args)
queue.put(result)
queue = Queue()
process = Process(target=target, args=(queue,))
process.start()
process.join(timeout=30)
if process.is_alive():
# 如果超时,终止进程
kill_process_tree(process.pid)
process.join()
print("Function timed out!")
return {"correct": False, "timeout": True}
else:
# 返回结果
return queue.get()
def extract_verilog(verilog_code):
"""
从 Verilog 代码中提取 module 声明部分(module_head)。
"""
pattern = re.compile(r"```verilog\s*([\s\S]*?)\s*```")
matches = re.findall(pattern, verilog_code)
if matches:
return matches[-1] # 返回匹配的 module 声明
return None
if __name__ == "__main__":
for part in range(16):
name = f"codev_dataset_165k_o1_part{part}"
with open(f"data/evolve/{name}.jsonl", "r") as f:
data_gold = list(map(json.loads, f.read().strip().splitlines()))
data_gold = [extract_verilog(x["response"][0]["content"]) for x in data_gold]
with open(f"results/evolve/sample/{name}.jsonl", "r") as f:
data_dut = list(map(json.loads, f.read().strip().splitlines()))
problem_ids = [x["problem_id"] for x in data_dut]
data_dut = [extract_verilog(x["response"][0]["content"]) for x in data_dut]
print(len(data_gold), len(data_dut), len(problem_ids))
assert len(data_dut) % len(data_gold) == 0
n_sample = len(data_dut) // len(data_gold)
testcases = []
for i, dut in enumerate(data_dut):
gold = data_gold[i // n_sample]
testcases.append((gold, dut, i))
# testcases = testcases[:1000]
if not os.path.exists("./tmp/testcase"):
os.makedirs("./tmp/testcase")
if not os.path.exists("./tmp/work"):
os.makedirs("./tmp/work")
# cpu_num = multiprocessing.cpu_count()
cpu_num = 64
# chunksize = max(len(testcases) // (cpu_num * 5), 1)
chunksize = 1
results = process_map(verify_one_sample_wrapper, testcases, max_workers=cpu_num, chunksize=chunksize)
for i in range(len(results)):
results[i]["problem_id"] = problem_ids[i]
with open(f"results/evolve/eval/{name}.jsonl", "w") as f:
f.write("\n".join(map(json.dumps, results)) + "\n")
print(f"{name}.jsonl is processed!!!")
import datetime
from itertools import combinations, product
import json
import math
import os
import random
import re
import shutil
import subprocess
import networkx as nx
from openai import OpenAI
from siliconcompiler import Chip # import python package
from siliconcompiler.targets import (
freepdk45_demo,
asap7_demo,
skywater130_demo,
) # import predefined technology and flow target
def llm_request(prompt, temperature=0.5):
api_key = os.getenv("tencent_key")
base_url = "https://api.lkeap.cloud.tencent.com/v1"
model = "deepseek-v3"
messages = [{"role": "user", "content": prompt}]
client = OpenAI(
api_key=api_key,
base_url=base_url,
)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
return response.choices[0].message.content
def extract_verilog_code(file_content):
# 去除注释
note_pattern = r"(//[^\n]*|/\*[\s\S]*?\*/)"
file_content = re.sub(note_pattern, "", file_content)
file_content = re.sub(r"(?:\s*?\n)+", "\n", file_content)
# 匹配编译指令中的define
define_pattern = r"`define\b\s+\b([a-zA-Z_][a-zA-Z0-9_$]*|\\[!-~]+?(?:\s|$))\b.*\n"
# TODO 匹配更多编译指令
# TODO 匹配 task 和 function
# task_function_pattern = r"\btask\b\s+([a-zA-Z_][a-zA-Z0-9_$]*|\\[!-~]+(?:\s|$))[\s\S]*?\bendtask\b|\bfunction\b[\s\S]*?\bendfunction\b"
# 匹配 module 到 endmodule 之间的内容,并提取模块名
module_pattern = r"\bmodule\s+([a-zA-Z_][a-zA-Z0-9_$]*|\\[!-~]+?(?:\s|$))\s*(?:\#\s*\([\s\S]*?\)\s*)?\((?:(?!\bmodule\b).)*?\)\s*;(?:(?!\bmodule\b).)*?\bendmodule\b"
# 使用字典来存储,保留每个匹配的对象中最后出现的一个
item_dict = {}
item_order = []
for match in re.finditer(
f"{module_pattern}|{define_pattern}", file_content, re.DOTALL
):
item_name = match.group(1)
if item_name not in item_order:
item_order.append(item_name)
item_dict[item_name] = match.group(0)
extracted = "\n".join([item_dict[item] for item in item_order])
return extracted
class eda_tools:
def __init__(
self,
golden_suffix="_gold",
gate_suffix="_gate",
use_directed_tests=False,
random_seq_steps=1000,
random_seq_num=100,
quiet=False,
):
"""
simulator: 仿真器,主要支持iverilog,verilator有bug
golden_suffix: 参考设计在testbench中的实例名后缀,默认_gold
gate_suffix: 待测设计在testbench中的实例名后缀,默认_gate
use_directed_tests: 是否使用定向测试,如果为True,则使用LLM生成的定向测试,否则使用随机测试,LLM生成效果较差,默认False
random_seq_steps: 随机测试每个序列的长度,越长越准确,运行时间越长,默认1000
random_seq_num: 随机测试的序列数,越多越准确,运行时间越长,默认100
quiet: 是否打印输出,默认False
"""
self.golden_suffix = golden_suffix
self.gate_suffix = gate_suffix
self.use_directed_tests = use_directed_tests
self.random_seq_steps = random_seq_steps
self.random_seq_num = random_seq_num
self.quiet = quiet
def auto_top(self, verilog_code):
"""
自动找到verilog代码中的顶层模块,当前实现为找到最大的调用子树的根节点,当两个调用字数大小相同时,选择字典序最小的
输入:
verilog_code: verilog代码字符串
输出:
top_module: 顶层模块名
"""
instance_graph = nx.DiGraph()
note_pattern = r"(//[^\n]*|/\*[\s\S]*?\*/)"
new_code = re.sub(note_pattern, "", verilog_code)
new_code = re.sub(r"(?:\s*?\n)+", "\n", new_code)
module_def_pattern = r"(module\s+)([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*\#\s*\([\s\S]*?\))?(\s*(?:\([^;]*\))?\s*;)([\s\S]*?)?(endmodule)"
module_defs = re.findall(module_def_pattern, new_code, re.DOTALL)
if not module_defs:
raise Exception("No module found in auto_top().")
module_names = [m[1] for m in module_defs]
instance_graph.add_nodes_from(module_names)
# 匹配 module 到 endmodule 之间的内容,并提取模块名
for mod in module_defs:
this_mod_name = mod[1]
this_mod_body = mod[4]
for submod in module_names:
if submod != this_mod_name:
module_instance_pattern = rf"({re.escape(submod)})(\s)(\s*\#\s*\([\s\S]*?\))?([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*(?:\([^;]*\))?\s*;)"
module_instances = re.findall(
module_instance_pattern, this_mod_body, re.DOTALL
)
if module_instances:
instance_graph.add_edge(this_mod_name, submod)
instance_tree_size = {}
for n in instance_graph.nodes:
if instance_graph.in_degree(n) == 0:
instance_tree_size[n] = nx.descendants(instance_graph, n)
top_module = max(instance_tree_size, key=instance_tree_size.get)
return top_module
def process_verilog(self, verilog_code, suffix):
"""读verilog代码,在所有模块定义和调用后面加上suffix,用于区分gold和gate设计"""
note_pattern = r"(//[^\n]*|/\*[\s\S]*?\*/)"
new_code = re.sub(note_pattern, "", verilog_code)
new_code = re.sub(r"(?:\s*?\n)+", "\n", new_code)
module_def_pattern = r"(module\s+)([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*\#\s*\([\s\S]*?\))?(\s*(?:\([^;]*\))?\s*;)([\s\S]*?)?(endmodule)"
module_defs = re.findall(module_def_pattern, new_code, re.DOTALL)
module_names = [m[1] for m in module_defs]
for submod in module_names:
module_instance_pattern = rf"({submod})(\s+)(\#\s*\([\s\S]*?\)\s*)?([a-zA-Z_][a-zA-Z0-9_\$]*|\\[!-~]+?(?=\s))(\s*(?:\([^;]*\))?\s*;)"
new_code = re.sub(module_instance_pattern, rf"\1{suffix}\2\3\4\5", new_code)
new_code = re.sub(module_def_pattern, rf"\1\2{suffix}\3\4\5\6", new_code)
return new_code
def extract_golden_ports(self, golden_path, golden_top, timeout=60):
"""
根据yosys的结果,提取golden模块的输入输出端口、时钟端口、复位端口。
golden_path: 参考设计的路径
golden_top: 参考设计的顶层模块名
输出:
为一个元组(input_port_width, output_port_width, clock_port_polarity, reset_port_polarity_sync)
input_port_width: 输入端口名、位宽
output_port_width: 输出端口名、位宽
clock_port_polarity: 时钟端口名、上升沿/下降沿触发
reset_port_polarity_sync: 复位端口名、高低电平有效、同步/异步复位
"""
golden_top = golden_top.lstrip("\\")
yosys_script = f"read_verilog {golden_path}; prep -top {golden_top} -flatten; opt_dff -nodffe; json -compat-int; exec -- echo 'Happy new year~';"
yosys_result = subprocess.run(
["yosys", "-p", yosys_script],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
timeout=timeout,
)
if yosys_result.stderr:
raise Exception(yosys_result.stderr.decode("utf-8"))
yosys_output = yosys_result.stdout.decode("utf-8")
yosys_json_text = re.search(
r'(\{\n\s+"creator":[\s\S]*\})\n+[\d]+\. Executing command',
yosys_output,
re.DOTALL,
).group(1)
yosys_json = json.loads(yosys_json_text)
ports_ids_dict = {}
input_port_width = set()
output_port_width = set()
if yosys_json["modules"] == {}:
raise Exception("No module found in yosys output after synthesis.")
for port_name in yosys_json["modules"][golden_top]["ports"]:
direction = yosys_json["modules"][golden_top]["ports"][port_name][
"direction"
]
bits = yosys_json["modules"][golden_top]["ports"][port_name]["bits"]
width = len(bits)
ports_ids_dict[port_name] = bits
if direction == "input":
input_port_width.add((port_name, width))
if direction == "output":
output_port_width.add((port_name, width))
clock_port_polarity = set()
reset_port_polarity_sync = set()
def find_single_port(port_id, ports_ids_dict):
if len(port_id) != 1:
raise Exception("Only support single port id now.")
for port_name, bits in ports_ids_dict.items():
if len(bits) == 1 and bits[0] == port_id[0]:
return port_name
elif len(bits) > 1 and port_id[0] in bits:
return f"{port_name}[{bits.index(port_id[0])}]"
else:
return None
for cell_id in yosys_json["modules"][golden_top]["cells"]:
cell = yosys_json["modules"][golden_top]["cells"][cell_id]
for reg_ports in cell["connections"]:
if reg_ports == "CLK":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"]["CLK_POLARITY"]
clock_port_polarity.add((port_name, polarity))
break
match cell["type"]:
case "$adff" | "$adffe" | "$adlatch":
for reg_ports in cell["connections"]:
if reg_ports == "ARST":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"]["ARST_POLARITY"]
sync = False
reset_port_polarity_sync.add(
(port_name, polarity, sync)
)
break
case "$sdff" | "$sdffe" | "$sdffce":
for reg_ports in cell["connections"]:
if reg_ports == "SRST":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"]["SRST_POLARITY"]
sync = True
reset_port_polarity_sync.add(
(port_name, polarity, sync)
)
break
case "$dffsr" | "$dffsre" | "$dlatchsr" | "$sr":
for reg_ports in cell["connections"]:
if reg_ports == "SET" or reg_ports == "CLR":
port_id = cell["connections"][reg_ports]
port_name = find_single_port(port_id, ports_ids_dict)
if port_name:
polarity = cell["parameters"][f"{reg_ports}_POLARITY"]
sync = False
reset_port_polarity_sync.add(
(port_name, polarity, sync)
)
break
case "$dlatch" | "$ff" | "$dff" | "$dffe" | "aldff" | "$aldffe":
pass
case _:
pass
if not self.quiet:
print(f"Input ports:")
for port, width in input_port_width:
print(f" {port}: {width}")
print(f"Output ports:")
for port, width in output_port_width:
print(f" {port}: {width}")
print(f"Clock ports:")
for port, polarity in clock_port_polarity:
print(f" {port}: {'posedge' if polarity else 'negedge'}")
print(f"Reset ports:")
for port, polarity, sync in reset_port_polarity_sync:
print(
f" {port}: {'high' if polarity else 'low'} {'sync' if sync else 'async'}"
)
return (
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
)
def generate_testbench(
self,
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
golden_top,
gate_top,
):
"""
根据golden模块和gate模块的输入输出端口、时钟端口、复位端口,生成testbench代码。返回值中不会带有定向测试的具体输入,定向测试的输入在write_code_testbench中添加。
输入:
input_port_width: 输入端口名、位宽,为一个集合,其中的元素为(port_name, width)
output_port_width: 输出端口名、位宽,为一个集合,其中的元素为(port_name, width)
clock_port_polarity:
时钟端口名、上升沿/下降沿触发
为一个集合,其中的元素为(port_name, polarity)
port_name是端口名,字符串
polarity是时钟信号的极性,1表示上升沿触发,0表示下降沿触发
reset_port_polarity_sync:
复位信号的端口名、高低电平有效、同步/异步复位
为一个集合,其中的元素为(port_name, polarity, sync)
port_name是端口名,字符串
polarity是复位信号的极性,1表示高电平有效,0表示低电平有效
sync是复位信号的同步异步,True表示同步复位,False表示异步复位
golden_top: 参考设计的顶层模块名
gate_top: 待测设计的顶层模块名
输出:不包括定向测试的具体输入值的testbench代码
"""
reset_port_names = set([p[0] for p in reset_port_polarity_sync])
if len(clock_port_polarity) > 1:
raise Exception(
"Multiple clock ports or multiple triggering edge detected, currently not supported."
)
clock_port_name = (
list(clock_port_polarity)[0][0] if clock_port_polarity else None
)
clock_port_edge = (
list(clock_port_polarity)[0][1] if clock_port_polarity else None
)
input_port_names = [p[0] for p in input_port_width]
output_port_names = [p[0] for p in output_port_width]
# 生成输入信号定义
input_defs = "\n ".join(
[f"reg [{width-1}:0] {port}_in ;" for port, width in input_port_width]
)
gold_output_defs = "\n ".join(
[
f"wire [{width-1}:0] {port}{self.golden_suffix} ;"
for port, width in output_port_width
]
)
gate_output_defs = "\n ".join(
[
f"wire [{width-1}:0] {port}{self.gate_suffix} ;"
for port, width in output_port_width
]
)
# 生成trigger信号,trigger信号为1时表示golden和gate输出不一致
trigger_assign = (
"\n always @(*) begin\n #5; trigger = ~( "
+ " & ".join(
[
f"{port}{self.golden_suffix} === {port}{self.gate_suffix}"
for port in output_port_names
]
+ ["1'b1"]
)
+ " );\n end\n"
)
# 实例化gold和gate模块的端口赋值语句
gold_port_mappings = ",\n ".join(
[f".{port}( {port}_in )" for port in input_port_names]
+ [f".{port}( {port}{self.golden_suffix} )" for port in output_port_names]
)
gate_port_mappings = ",\n ".join(
[f".{port}( {port}_in )" for port in input_port_names]
+ [f".{port}( {port}{self.gate_suffix} )" for port in output_port_names]
)
# 生成随机化输入信号的task
randomize_inputs_lines = "\n ".join(
[
f"{port}_in = {{{', '.join(['$random(seed)']*math.ceil(width/32))}}};"
for port, width in input_port_width
if port not in [clock_port_name] + list(reset_port_names)
]
)
randomize_inputs_task = f"""// task to generate random inputs
task randomize_inputs;
begin
{randomize_inputs_lines}
end
endtask
"""
# 根据复位信号的极性和同步异步,进行组合,生成不同组合下复位的task
# 按照reset端口名进行分组
grouped = {}
for port, polarity, sync in reset_port_polarity_sync:
if port not in grouped:
grouped[port] = []
grouped[port].append((port, polarity, sync))
# 生成所有可能的组合,一个port name在一个组合中最多只能出现一次
all_reset_combinations = []
for r in range(1, len(grouped) + 1):
for ports in combinations(grouped.keys(), r):
for polarities_syncs in product(*[grouped[port] for port in ports]):
all_reset_combinations.append(list(polarities_syncs))
# 根据每种组合生成复位信号的task,同步复位的task中,复位信号赋值后需要clock完成一次上升和下降沿
reset_task_list = []
for i, reset_comb in enumerate(all_reset_combinations):
sync_reset_lines = []
async_reset_lines = []
unset_lines = []
for port, polarity, sync in reset_comb:
if sync:
sync_reset_lines.append(f"{port}_in = {polarity};")
else:
async_reset_lines.append(f"{port}_in = {polarity};")
unset_lines.append(f"{port}_in = {0 if polarity == 1 else 1};")
reset_lines = (
(
"\n ".join(sync_reset_lines)
+ "\n # 10; toggle_clock; # 10; toggle_clock;\n "
+ "\n ".join(unset_lines)
)
if sync_reset_lines
else "" + "\n ".join(async_reset_lines + unset_lines)
)
reset_task = f"""task reset_{i};
begin
{reset_lines}
end
endtask
"""
reset_task_list.append(reset_task)
# 生成定向测试的task,定向测试的赋值由用户改写或LLM生成
directed_tests_task = f"""// Task for directed test. The inputs should be able to activate all functionalities in the golden design, and checks whether the gate design and the golden design are equivalent.
task directed_tests;
begin
// [TODO] directed tests here.
{'# 10; toggle_clock; # 10; toggle_clock;' if clock_port_name else ''}
end
endtask
"""
# 生成翻转时钟信号的task
toggle_clock_task = f"""// Task to toggle {clock_port_name}_in
task toggle_clock;
begin
{clock_port_name}_in = ~{clock_port_name}_in ;
end
endtask
"""
count_errors_task = f"""// Task to count errors
task count_errors;
begin
if (trigger === 1'b1) begin
num_errors = num_errors + 1;
end
num_all = num_all + 1;
end
endtask
"""
# 生成随机复位信号的task
random_reset_lines = "\n ".join(
[f"{port}_in = $random(seed);" for port in reset_port_names]
)
random_reset_task = f"""// Task for random reset
task random_reset;
begin
{random_reset_lines}
end
endtask
"""
# 生成 initial block
initial_block_lines = [
"// initial block for random tests and targed tests",
"initial begin",
' if (!$value$plusargs("seed=%d", seed)) seed = 0;',
f' if (!$value$plusargs("outerLoopNum=%d", outerLoopNum)) outerLoopNum = {self.random_seq_num};',
f' if (!$value$plusargs("innerLoopNum=%d", innerLoopNum)) innerLoopNum = {self.random_seq_steps};',
(
f" {clock_port_name}_in = {0 if clock_port_edge else 1};"
if clock_port_name
else ""
),
f" repeat (outerLoopNum) begin",
" random_reset;" if reset_port_names else "",
" #100; count_errors;",
f" repeat (innerLoopNum) begin",
" #100; randomize_inputs;",
" #100; toggle_clock;" if clock_port_name else "",
" #100; count_errors;",
" end",
" end",
]
if reset_port_names:
initial_block_lines.append(" #100;")
for i in range(len(reset_task_list)):
initial_block_lines.append(
f" repeat (outerLoopNum) begin",
)
initial_block_lines.append(f" reset_{i};")
initial_block_lines.append(f" #100; count_errors;")
initial_block_lines.append(
f" repeat (innerLoopNum) begin",
)
initial_block_lines.append(f" #100; randomize_inputs;")
(
initial_block_lines.append(f" #100; toggle_clock;")
if clock_port_name
else ""
)
initial_block_lines.append(f" #100; count_errors;")
initial_block_lines.append(f" end")
initial_block_lines.append(f" end")
if self.use_directed_tests:
if reset_port_names:
for i in range(len(reset_task_list)):
initial_block_lines.append(f" reset_{i};")
initial_block_lines.append(" #100;")
initial_block_lines.append(" directed_tests;")
initial_block_lines.append(" #100;")
else:
initial_block_lines.append(" directed_tests;")
initial_block_lines += [
' $display("Number of all tests: %d", num_all);',
' $display("Number of errors: %d", num_errors);',
' $display("Error rate: %.8f", num_errors/num_all);',
" if (num_errors == 0) begin",
' $display("All tests passed.");',
" end",
" $finish;",
"end",
]
initial_block = "\n ".join(initial_block_lines)
# 生成监测输出信号的 always block
monitor_block = f"""always @(trigger) begin
if (trigger === 1'b1) begin
$error("trigger signal is 1, which is not allowed!");
$finish;
end
end
"""
# 生成完整的 testbench 代码
testbench_code = f"""
module testbench;
{input_defs}
{gold_output_defs}
{gate_output_defs}
reg trigger;
real num_all = 0;
real num_errors = 0;
integer seed;
integer outerLoopNum;
integer innerLoopNum;
{golden_top}{self.golden_suffix} gold (
{gold_port_mappings}
);
{gate_top}{self.gate_suffix} gate (
{gate_port_mappings}
);
{trigger_assign}
{toggle_clock_task if clock_port_name else ""}
{''.join(reset_task_list) if reset_port_names else ""}
{random_reset_task if reset_port_names else ""}
{randomize_inputs_task}
{directed_tests_task if self.use_directed_tests else ""}
{count_errors_task}
{initial_block}
endmodule
"""
return testbench_code
def generate_directed_test(
self,
golden_code,
tb_code,
):
"""
生成定向测试的输入并插入testbench代码中,返回插入后的testbench代码
输入:
golden_code: 参考设计的代码
tb_code: testbench的代码
输出:
返回值为(renamed_golden_code, tb_module_code)
"""
renamed_golden_code = self.process_verilog(golden_code, self.golden_suffix)
fim_code = renamed_golden_code + tb_code
if self.use_directed_tests:
print("Generating directed tests with LLM...") if not self.quiet else None
fim_prompt = f"""
Given the following Verilog design code and its testbench:
```verilog
{fim_code}
```
Please complete the directed test inputs in the testbench. Only provide the code that replaces the "[TODO] directed tests here." section, wrapped in a ```verilog``` code block. Do not include any other content or explanations in your response.
Example of expected response format:
```verilog
// Your directed test code here
```
"""
try:
response = llm_request(fim_prompt)
directed_inputs = re.findall(
r"```verilog(.*?)```", response, re.DOTALL
)[0]
tb_module_code = tb_code.replace(
"// [TODO] directed tests here.", directed_inputs
)
except Exception as e:
print(e)
print("Failed to generate directed tests with LLM.")
return renamed_golden_code, tb_module_code
def write_code_testbench(
self,
golden_path,
gate_path,
golden_top,
gate_top,
tb_dir,
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
):
"""
写gold设计和gate设计和testbench到文件,定向测试的具体输入在本函数中添加。
输入:
golden_path: 参考设计的path
gate_path: 待测设计的path
golden_top: 参考设计的顶层模块名
gate_top: 待测设计的顶层模块名
tb_dir: 生成的testbench所在的路径,包括Makefile,verilator在这个路径下运行
输出:
返回值为(renamed_golden_code, renamed_gate_code, tb_module_code)
"""
if not os.path.exists(tb_dir):
os.makedirs(tb_dir)
with open(golden_path, "r") as f:
golden_code = f.read()
print("Processing golden code...") if not self.quiet else None
renamed_golden_code = self.process_verilog(golden_code, self.golden_suffix)
if gate_path is not None:
with open(gate_path, "r") as f:
gate_code = f.read()
print("Processing gate code...") if not self.quiet else None
renamed_gate_code = self.process_verilog(gate_code, self.gate_suffix)
with open(os.path.join(tb_dir, "gate.v"), "w") as f:
f.write(renamed_gate_code)
tb_module_code = self.generate_testbench(
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
golden_top,
gate_top,
)
fim_code = renamed_golden_code + tb_module_code
if self.use_directed_tests:
print("Generating directed tests with LLM...") if not self.quiet else None
fim_prompt = f"""
Given the following Verilog design code and its testbench:
```verilog
{fim_code}
```
Please complete the directed test inputs in the testbench. Only provide the code that replaces the "[TODO] directed tests here." section, wrapped in a ```verilog``` code block. Do not include any other content or explanations in your response.
Example of expected response format:
```verilog
// Your directed test code here
```
"""
try:
response = llm_request(fim_prompt)
directed_inputs = re.findall(r"```verilog(.*?)```", response, re.DOTALL)
tb_module_code = tb_module_code.replace(
"// [TODO] directed tests here.", directed_inputs
)
except Exception as e:
print(e)
print("Failed to generate directed tests with LLM.")
with open(os.path.join(tb_dir, "tb.v"), "w") as f:
f.write(renamed_golden_code)
f.write("\n")
f.write(tb_module_code)
return renamed_golden_code, renamed_gate_code, tb_module_code
def equiv_with_testbench(
self,
golden_path,
gate_path,
golden_top,
gate_top,
tb_dir,
seed=0,
outerLoopNum=None,
innerLoopNum=None,
timeout=60,
):
"""
一个wrapper,用于生成testbench并运行,输出测试结果,如果运行的输出中有"all tests passed."则测试通过,否则测试失败
输入:
golden_path: 参考设计的path
gate_path: 待测设计的path
golden_top: 参考设计的顶层模块名
gate_top: 待测设计的顶层模块名
tb_dir: 生成的testbench所在的路径,包括Makefile,verilator在这个路径下运行
seed: 控制testbench测试时的随机种子,默认0,和之前对齐
outerLoopNum: 控制testbench测试时的外层循环次数,默认None,此时采用self.random_seq_num和之前对齐
innerLoopNum: 控制testbench测试时的内层循环次数,默认None,此时采用self.random_seq_steps和之前对齐
timeout: 仿真超时时间,默认60s
输出:
返回值为True或False,表示测试通过或失败
"""
if outerLoopNum == None:
outerLoopNum = self.random_seq_num
if innerLoopNum == None:
innerLoopNum = self.random_seq_steps
(
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
) = self.extract_golden_ports(golden_path, golden_top)
self.write_code_testbench(
golden_path=golden_path,
gate_path=gate_path,
golden_top=golden_top,
gate_top=gate_top,
tb_dir=tb_dir,
input_port_width=input_port_width,
output_port_width=output_port_width,
clock_port_polarity=clock_port_polarity,
reset_port_polarity_sync=reset_port_polarity_sync,
)
command = f"iverilog -g2012 -o {os.path.join(tb_dir,'tb.vvp')} -s testbench {os.path.join(tb_dir,'*.v')} && {os.path.join(tb_dir,f'tb.vvp +seed={seed} +outerLoopNum={outerLoopNum} +innerLoopNum={innerLoopNum}')}"
res = subprocess.run(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
timeout=timeout,
)
error_rate_pattern = r"Error rate:\s*(\d+\.\d+)\n"
print(res.stdout.decode("utf-8")) if not self.quiet else None
if re.search(error_rate_pattern, res.stdout.decode("utf-8")):
error_rate = float(
re.search(error_rate_pattern, res.stdout.decode("utf-8")).group(1)
)
else:
error_rate = 1.0
# TODO 实现得比较丑陋
if "All tests passed." in res.stdout.decode("utf-8"):
print("Test passed!") if not self.quiet else None
return (
True,
error_rate,
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
)
else:
print("Test failed!") if not self.quiet else None
return (
False,
error_rate,
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
)
def synthesis_with_siliconcompiler(
self,
job_name: str,
rtl_paths: list,
top: str,
clk: str = None,
tech: str = "freepdk45",
timeout=60,
build_dir="./work/build",
cache_dir="./siliconcompiler",
) -> dict:
"""
用来综合一个设计,返回综合结果,包括cell_area, peak_power, arrival_time
输入:
job_name: 任务名,用于区分不同的任务
rtl_paths: rtl代码的路径,是一个list
top: 顶层模块名
tech: 技术库,支持freepdk45, asap7, skywater130,默认freepdk45
rm_build: 是否删除build目录,默认True
输出:
返回值为一个字典,包括cell_area, peak_power, arrival_time
"""
chip = Chip(top) # create chip object
chip.set("option", "builddir", build_dir)
chip.set("option", "cachedir", cache_dir)
chip.set("option", "jobname", job_name)
chip.set("option", "clean", True)
chip.set("option", "loglevel", "critical" if self.quiet else "info")
chip.set("option", "timeout", timeout)
for path in rtl_paths:
chip.input(path)
if clk:
chip.clock(clk, period=0.1) # define clock speed of design
match tech:
case "freepdk45":
chip.use(freepdk45_demo)
case "asap7":
chip.use(asap7_demo)
case "skywater130":
chip.use(skywater130_demo)
case _:
raise ValueError(f"Unsupported technology {tech}")
chip.set("option", "flow", "synflow")
chip.set("option", "remote", False) # run remote in the cloud
chip.run() # run compilation of design and target
cellarea = chip.get("metric", "cellarea", step="timing", index=0)
peakpower = chip.get("metric", "peakpower", step="timing", index=0)
workdir = chip.getworkdir(step="timing", index=0)
with open(os.path.join(workdir, "reports/unconstrained.rpt"), "r") as f:
rpt = f.read()
at_pattern = r"^\s+(\d*\.?\d*)\s+data arrival time"
try:
arrival_time = float(re.search(at_pattern, rpt, re.MULTILINE).group(1))
except:
arrival_time = None
ppa = {
"cell_area": cellarea,
"peak_power": peakpower,
"arrival_time": arrival_time,
}
return ppa
class myLogger:
def __init__(self):
self.log = []
def info(self, info_content):
self.log.append([str(datetime.datetime.now()), "INFO", info_content])
def debug(self, debug_content):
self.log.append([str(datetime.datetime.now()), "DEBUG", debug_content])
def output(self, level):
match level:
case "info":
lines = [l for l in self.log if l[1] == "INFO"]
text = "\n".join([" - ".join(l) for l in lines])
return text
case "debug":
text = "\n".join([" - ".join(l) for l in self.log])
return text
case _:
raise Exception("Unsupported. Only support info and debug.")
def main():
eda = eda_tools(quiet=True)
with open("./temp/gold.v", "r") as f:
gold_code = f.read()
gold_top = eda.auto_top(gold_code)
with open("./temp/gate.v", "r") as f:
gate_code = f.read()
gate_top = eda.auto_top(gate_code)
(
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
) = eda.extract_golden_ports("./temp/gold.v", gold_top)
tb = eda.generate_testbench(
input_port_width,
output_port_width,
clock_port_polarity,
reset_port_polarity_sync,
gold_top,
gate_top,
)
with open("./temp/tb.v", "w") as f:
f.write(eda.process_verilog(gold_code, eda.golden_suffix))
f.write("\n")
f.write(eda.process_verilog(gate_code, eda.gate_suffix))
f.write("\n")
f.write(tb)
equiv, error_rate, _, _, _, _ = eda.equiv_with_testbench(
"./temp/gold.v",
"./temp/gate.v",
gold_top,
gate_top,
"./temp/testbench",
seed=0,
outerLoopNum=100,
innerLoopNum=1000,
)
print(equiv, error_rate)
if __name__ == "__main__":
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment