import re
from verl.utils.reward_score.codev_eval_toolkit.eval_codev import verify_one_sample, verify_one_sample_wrapper, extract_verilog


# The four tags a response must contain, in the order they must appear.
_FORMAT_TAGS = ("<think>", "</think>", "<answer>", "</answer>")


def _tags_in_order(positions):
    """Return True when all four tag positions were found and strictly increase."""
    return min(positions) >= 0 and positions[0] < positions[1] < positions[2] < positions[3]


def _score_with_length_leniency(solution_str, ground_truth, exceed_length, correct_reward):
    """Shared scoring logic behind the numbered ``compute_score_*`` variants.

    Args:
        solution_str: Text to score. If it contains ``<|im_start|>assistant``,
            only the text from the first such marker onwards is inspected.
        ground_truth: Reference passed to ``verify_one_sample_wrapper``
            together with the extracted answer.
        exceed_length: When true, a response whose tags are in order but whose
            trailing tags are missing scores 0.0 instead of -1.0.
        correct_reward: Value returned when verification reports success.

    Returns:
        ``correct_reward`` when the verifier reports the answer correct,
        -0.5 when it does not, 0.0 for a tolerated partial response, and
        -1.0 for any other malformed response.
    """
    response_pos = solution_str.find("<|im_start|>assistant")
    if response_pos >= 0:
        solution_str = solution_str[response_pos:]

    # Only the first occurrence of each tag is considered.
    positions = [solution_str.find(tag) for tag in _FORMAT_TAGS]

    # Extraction is only worth running once the format is known to be valid;
    # a malformed response is penalised regardless of what it contains.
    extracted_answer = extract_verilog(solution_str) if _tags_in_order(positions) else None

    if extracted_answer is None:
        if exceed_length:
            # Pretend each missing tag sits right after its predecessor, so a
            # response that simply stopped early can still count as ordered.
            partial = list(positions)
            for i in range(1, len(partial)):
                if partial[i] == -1 and partial[i - 1] >= 0:
                    partial[i] = partial[i - 1] + 1
            if _tags_in_order(partial):
                return 0.0
        return -1.0

    result = verify_one_sample_wrapper((ground_truth, extracted_answer))
    # Equality (not identity/truthiness) is kept on purpose: the type of
    # result["correct"] is decided by the verifier, which is not visible here.
    if result["correct"] == True:  # noqa: E712
        return correct_reward
    return -0.5


def compute_score_617795(solution_str, ground_truth, exceed_length=False):
    """Score a response: 1.0 correct, -0.5 incorrect, 0.0 partial, -1.0 malformed.

    See ``_score_with_length_leniency`` for the meaning of each argument.
    """
    return _score_with_length_leniency(solution_str, ground_truth, exceed_length, 1.0)


def compute_score_618832(solution_str, ground_truth, exceed_length=False):
    """Score a response: 3.0 correct, -0.5 incorrect, 0.0 partial, -1.0 malformed.

    Identical to ``compute_score_617795`` except for the larger reward given
    to a correct answer.
    """
    return _score_with_length_leniency(solution_str, ground_truth, exceed_length, 3.0)


def compute_score(solution_str, ground_truth, **kwargs):
    """Score a generated Verilog answer against a reference.

    Args:
        solution_str: Text to score. If it contains ``<|im_start|>assistant``,
            only the text from the first such marker onwards is inspected.
        ground_truth: Reference passed to ``verify_one_sample_wrapper``
            together with the extracted answer.
        **kwargs: Optional settings:
            reward_mode: ``'discrete'`` (default) or ``'continuous'``.
            err_threshold: Largest ``error_rate`` that still earns partial
                reward; required when ``reward_mode`` is ``'continuous'``.
            reward_mapping: ``'threshold'`` or ``'zero'``; selects the
                partial-reward formula and is required when ``reward_mode``
                is ``'continuous'``. It is not consulted in discrete mode and
                may be omitted there.

    Returns:
        0.0 when the response is malformed or no answer can be extracted,
        1.0 when the verifier reports the answer correct, and otherwise 0.0
        in discrete mode or a partial reward derived from ``error_rate`` in
        continuous mode.

    Raises:
        AssertionError: If the keyword settings are invalid.
    """
    reward_mode = kwargs.get('reward_mode', 'discrete')
    err_threshold = kwargs.get('err_threshold', None)
    reward_mapping = kwargs.get('reward_mapping', None)
    assert reward_mode in ['discrete', 'continuous'], "mode should be either 'discrete' or 'continuous'"
    assert reward_mode != 'continuous' or err_threshold is not None, "err_threshold should be given when using continuous reward!"
    # reward_mapping is only read in the continuous branch below, so (like
    # err_threshold) it is only mandatory there; rejecting the default of None
    # in discrete mode made every call without the option fail.
    assert reward_mapping in [None, 'threshold', 'zero'], "reward_mapping should be either 'threshold' or 'zero'"
    assert reward_mode != 'continuous' or reward_mapping is not None, "reward_mapping should be given when using continuous reward!"

    response_pos = solution_str.find("<|im_start|>assistant")
    if response_pos >= 0:
        solution_str = solution_str[response_pos:]
    # Otherwise the marker is absent, which probably means the prompt was too
    # long; prompt length is capped upstream, so this should not matter.

    # Each tag must appear exactly once, and in this order.
    tags = ["<think>", "</think>", "<answer>", "</answer>"]
    tag_count = [solution_str.count(tag) for tag in tags]
    positions = [solution_str.find(tag) for tag in tags]
    well_formed = (
        min(tag_count) == max(tag_count) == 1
        and positions[0] < positions[1] < positions[2] < positions[3]
    )
    if not well_formed:
        # A malformed response scores zero regardless of its content, so
        # extraction is skipped.
        return 0.0

    extracted_answer = extract_verilog(solution_str)
    if extracted_answer is None:
        return 0.0

    result = verify_one_sample_wrapper((ground_truth, extracted_answer))
    # Equality (not identity/truthiness) is kept on purpose: the type of
    # result["correct"] is decided by the verifier, which is not visible here.
    if result["correct"] == True:  # noqa: E712
        return 1.0
    if reward_mode == 'discrete':
        return 0.0

    # Using error_rate should be fine for GRPO, which compares samples of the
    # same problem; this is less certain for other algorithms.
    if 'error_rate' in result and result['error_rate'] <= err_threshold:
        if reward_mapping == 'threshold':
            return 1 - result['error_rate']
        # NOTE(review): the 'zero' mapping only reaches 0 at the threshold when
        # err_threshold is 0.5 -- confirm this formula is the intended one.
        return 1 - err_threshold - result['error_rate']
    return 0.0


def compute_score_wrapper(data_source, solution_str, ground_truth, extra_info, **kwargs):
    """Adapter exposing ``compute_score`` under a five-argument signature.

    ``data_source`` and ``extra_info`` are accepted but not used; the
    remaining arguments and all keyword options are forwarded unchanged.
    """
    score = compute_score(solution_str, ground_truth, **kwargs)
    return score


if __name__ == '__main__':
    file = "/nfs_global/S/zhuyaoyu/projects/CodeV-o1/data/source/codev_dataset_165k_wo_module_head.jsonl"
    import json
    with open(file, "r") as f:
        data = list(map(json.loads, f.read().strip().splitlines()))
    
    sep = "============================================"
    # 正确
    example_ans = data[0]["response"]
    example_output = f"<think></think>  <answer>\n```verilog\n{example_ans}```\n</answer>"
    reward = compute_score(example_output, example_ans)
    print(f"{sep}\n{example_output}\n{sep}\n{reward}")

    # 格式错误
    example_output = f"<think> <answer></think> ```verilog\n{example_ans}```</answer>"
    reward = compute_score(example_output, example_ans)
    print(f"{sep}\n{example_output}\n{sep}\n{reward}")

    # 答案错误
    example_output = f"<think> </think> <answer>\n```verilog\n```\n</answer>"
    reward = compute_score(example_output, example_ans)
    print(f"{sep}\n{example_output}\n{sep}\n{reward}")

    example_output = """
<|im_start|>system
You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.  Now the user asks you to write verilog code. After thinking, when you finally reach a conclusion, enclose the final verilog code in ```verilog ``` within <answer> </answer> tags. i.e., <answer> ```verilog
 module top_module(in, out, ...) ... ``` </answer>.
<|im_end|>
<|im_start|>user
Design a signed adder that adds two 8-bit 2's complement numbers, a[7:0] and b[7:0]. These numbers are added to produce s[7:0]. Additionally, compute whether a signed overflow has occurred during the addition. The overflow condition should be detected when both operands have the same sign but the result differs in sign. This Verilog module, named `signed_adder`, has the interface designed as follows:

| Signal Name | Direction | Width | Description                                   |
|-------------|-----------|-------|-----------------------------------------------|
| `a`         | Input     | 8     | First 8-bit signed 2's complement number      |
| `b`         | Input     | 8     | Second 8-bit signed 2's complement number     |
| `s`         | Output    | 8     | 8-bit signed result of the addition           |
| `overflow`  | Output    | 1     | High if a signed overflow has occurred        |

<|im_end|>
<|im_start|>assistant
<think> 
To design a signed adder that adds two 8-bit 2's complement numbers and detects overflow, we need to follow these steps:

1. **Add the two 8-bit numbers**: Use a simple adder to get the sum.
2. **Detect overflow**: Check if both operands have the same sign but the result has a different sign. This indicates an overflow.

The overflow condition can be detected using the following logic:
- If both `a[7]` and `b[7]` are 0 (both are positive) and `s[7]` is 1 (result is negative), an overflow has occurred.
- If both `a[7]` and `b[7]` are 1 (both are negative) and `s[7]` is 0 (result is positive), an overflow has occurred.

Let's write the Verilog code based on this logic.
</think>

<answer> ```verilog
module signed_adder (
    input [7:0] a,
    input [7:0] b,
    output reg [7:0] s,
    output reg overflow
);

always @(*) begin
    // Perform the addition
    s = a + b;
    
    // Determine overflow
    overflow = (a[7] == b[7]) & (a[7] != s[7]);
end

endmodule
``` </answer><|im_end|>"""
    example_ans = """
module signed_adder(input [7:0] a,
    input [7:0] b,
    output reg [7:0] s,
    output reg overflow
);
    // Perform the addition
    assign s = a + b;

    // Detect overflow
    assign overflow = (a[7] == b[7]) && (a[7] != s[7]);
endmodule"""
    # 正确
    reward = compute_score(example_output, example_ans)
    print(f"{sep}\n{example_output}\n{sep}\n{reward}")

    # with open("/nfs_global/S/zhuyaoyu/projects/logicRL/outputs/all_wrong_problems.jsonl", "r") as f:
    #     data = list(map(json.loads, f.read().strip().splitlines()))
    
    # for item in data[::32]:
    #     response, gt = item['responses'], item['ground_truth']
    #     reward = compute_score(response, gt)
    #     # print(f"{sep}\n{response}\n{sep}\n{reward}")
    #     print(reward)
    #     # print(item['problem_id'])
    #     # print(response)

    # 格式错误
    example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> <answer><think></think>\n```verilog\n```\n</answer>"
    reward = compute_score(example_output, example_ans)
    print(f"{sep}\n{example_output}\n{sep}\n{reward}")

    # 格式错误
    example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> </think> <answer>\n```verilog\n```\n</answer>"
    reward = compute_score(example_output, example_ans)
    print(f"{sep}\n{example_output}\n{sep}\n{reward}")

    # 答案错误
    example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> <answer>\n```verilog\n```\n</answer>"
    reward = compute_score(example_output, example_ans)
    print(f"{sep}\n{example_output}\n{sep}\n{reward}")

    example_output = f"<|im_start|>system\nxxx\n<|im_end|>\n<|im_start|>user\nyyy\n<|im_end|>\n<|im_start|>assistant\n<think> </think> <answer>```verilog\nmodule my374_labl;\n reg temp;\nendmodule\n```</answer>"
    example_ans = "module my374 lab1 ();\n\treg temp;\nendmodule"