from verl.utils.reward_score.codev_eval_toolkit.verify import eda_tools
import json
import re
import os
from tqdm.contrib.concurrent import process_map
from multiprocessing import Process, Queue
import psutil
import hashlib
import random
import platform

# # Import the file-locking module appropriate for the current platform
# if platform.system() == 'Windows':
#     import msvcrt
# else:
#     import fcntl

# # Hypothetical lock file path
# LOCK_FILE_PATH = '.lock'


# def create_lock_file():
#     if not os.path.exists(LOCK_FILE_PATH):
#         with open(LOCK_FILE_PATH, 'w') as f:
#             pass


# def acquire_lock():
#     if platform.system() == 'Windows':
#         f = open(LOCK_FILE_PATH, 'r+')
#         msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1)
#         return f
#     else:
#         f = open(LOCK_FILE_PATH, 'r+')
#         fcntl.flock(f.fileno(), fcntl.LOCK_EX)
#         return f


# def release_lock(f):
#     if platform.system() == 'Windows':
#         msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
#     else:
#         fcntl.flock(f.fileno(), fcntl.LOCK_UN)
#     f.close()



def verify_one_sample(gold_code, dut_code, uid=None):
    """Compare a candidate Verilog design against a reference design.

    Both sources are written to scratch files under ``./tmp/testcase`` and
    passed, together with a per-call work directory under ``./tmp/work``, to
    ``eda_tools.equiv_with_testbench``.  The scratch files and the work
    directory are removed before the function returns.

    Args:
        gold_code: Reference Verilog source text.
        dut_code: Candidate Verilog source text.
        uid: Accepted for backward compatibility but ignored.  The scratch
            file name is always an MD5 digest of ``dut_code`` plus a random
            integer, so concurrent calls do not share paths.

    Returns:
        A dict that always contains ``"correct"``:

        * ``{"correct": False}`` when either source is empty or ``None``.
        * ``{"correct": False, "parse_error": <exception args>}`` when
          ``auto_top`` raises (presumably a syntax problem in one of the
          sources -- confirm against ``eda_tools``).
        * ``{"correct": False, "test_error": <exception args>}`` when
          ``equiv_with_testbench`` raises.
        * ``{"correct": r[0], "error_rate": r[1], "detail": r[2]}`` built from
          the value ``r`` returned by ``equiv_with_testbench``.
    """
    import shutil  # local import: only needed for scratch-directory cleanup

    # Guard first: dut_code may be None (e.g. no fenced block was found), and
    # concatenating None below would raise TypeError instead of reporting a
    # failed sample.
    if not gold_code or not dut_code:
        return {"correct": False}

    # Random suffix keeps scratch paths distinct even for identical dut_code.
    seed = dut_code + str(random.randint(0, 2147483647))
    uid = hashlib.md5(seed.encode("utf-8")).hexdigest()
    v = eda_tools(quiet=True)

    try:
        gold_top = v.auto_top(gold_code)
        gate_top = v.auto_top(dut_code)
    except Exception as e:
        # Failure while analysing the sources; reported rather than raised.
        return {"correct": False, "parse_error": e.args}

    gold_path = f"./tmp/testcase/{uid}_gold.v"
    dut_path = f"./tmp/testcase/{uid}_dut.v"
    test_path = f"./tmp/work/{uid}"

    # exist_ok=True makes these calls safe when many workers race to create
    # the same parent directories; makedirs also creates ./tmp/work.
    os.makedirs("./tmp/testcase", exist_ok=True)
    os.makedirs(test_path, exist_ok=True)

    try:
        with open(gold_path, "w") as f:
            f.write(gold_code)
        with open(dut_path, "w") as f:
            f.write(dut_code)

        # Generate a testbench and run it against both designs.
        try:
            equiv = v.equiv_with_testbench(
                gold_path,
                dut_path,
                gold_top,
                gate_top,
                test_path,
            )
        except Exception as e:
            return {"correct": False, "test_error": e.args}
    finally:
        # Always clean up, including when writing the scratch files failed.
        for path in (gold_path, dut_path):
            if os.path.exists(path):
                os.remove(path)
        shutil.rmtree(test_path, ignore_errors=True)

    return {"correct": equiv[0], "error_rate": equiv[1], "detail": equiv[2]}


def kill_process_tree(pid):
    """Request termination of process ``pid`` and all of its descendants.

    Descendants are collected recursively and terminated first, then the
    process itself.  Processes that have already exited by the time they are
    looked up or signalled are skipped silently, so calling this on a process
    that is just finishing does not raise.

    Args:
        pid: Process id of the root of the tree to terminate.
    """
    try:
        parent = psutil.Process(pid)
        targets = parent.children(recursive=True)  # all descendant processes
    except psutil.NoSuchProcess:
        # The root is already gone; there is nothing left to enumerate.
        return

    # Children first, root last -- same order as signalling them one by one.
    targets.append(parent)
    for proc in targets:
        try:
            proc.terminate()
        except psutil.NoSuchProcess:
            # Exited between enumeration and the terminate call.
            pass


def _run_verification(args, result_queue):
    """Child-process entry point: verify one sample and send the result back."""
    try:
        result = verify_one_sample(*args)
    except Exception as exc:
        # Report the failure instead of dying silently, which would leave the
        # parent waiting for a result that never arrives.
        result = {"correct": False, "worker_error": repr(exc)}
    result_queue.put(result)


def verify_one_sample_wrapper(args):
    """Run ``verify_one_sample(*args)`` in a subprocess with a 30 s time limit.

    Args:
        args: Tuple of positional arguments for ``verify_one_sample``.

    Returns:
        The dict produced by ``verify_one_sample``; or
        ``{"correct": False, "timeout": True}`` when no result arrived within
        the time limit (the subprocess tree is terminated); or
        ``{"correct": False, "worker_error": <text>}`` when the subprocess
        raised or exited without delivering a result.
    """
    from queue import Empty
    import time

    timeout = 30
    result_queue = Queue()
    # Module-level target so the Process is also picklable under "spawn".
    process = Process(target=_run_verification, args=(args, result_queue))
    process.start()

    # Read the queue BEFORE joining: a child that is still flushing a large
    # result into the pipe cannot exit until the parent consumes it, so
    # join-then-get can misreport a finished job as a timeout.
    deadline = time.monotonic() + timeout
    result = None
    while result is None:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            result = result_queue.get(timeout=min(remaining, 0.5))
        except Empty:
            if not process.is_alive():
                # The child may have put its result right before exiting;
                # give the pipe one last chance before declaring a crash.
                try:
                    result = result_queue.get(timeout=0.5)
                except Empty:
                    pass
                break

    if result is None and process.is_alive():
        # Timed out: terminate the subprocess and everything it spawned.
        kill_process_tree(process.pid)
        process.join()
        print("Function timed out!")
        return {"correct": False, "timeout": True}

    process.join()
    if result is None:
        return {
            "correct": False,
            "worker_error": f"worker exited with code {process.exitcode} without a result",
        }
    return result


def extract_verilog(verilog_code):
    """
    Return the body of the last ```verilog fenced block in ``verilog_code``.

    Whitespace directly inside the fence is not part of the returned text.
    Returns None when the text contains no ```verilog fenced block.
    """
    fence = re.compile(r"```verilog\s*([\s\S]*?)\s*```")
    last_body = None
    # Walk every fenced block and remember the most recent one.
    for hit in fence.finditer(verilog_code):
        last_body = hit.group(1)
    return last_body


if __name__ == "__main__":
    # Batch evaluation: for each of the 16 dataset parts, pair every sampled
    # response with its reference response, verify each pair in parallel and
    # write one JSON result per line.

    # Create scratch and output directories up front so a missing output
    # directory cannot fail the run after all verification work is done.
    for directory in ("./tmp/testcase", "./tmp/work", "results/evolve/eval"):
        os.makedirs(directory, exist_ok=True)

    for part in range(16):
        name = f"codev_dataset_165k_o1_part{part}"

        # Iterate the file line by line instead of str.splitlines(): the
        # latter also splits on characters such as U+2028 that may appear
        # unescaped inside JSON strings and would cut a record in two.
        with open(f"data/evolve/{name}.jsonl", "r", encoding="utf-8") as f:
            gold_records = [json.loads(line) for line in f if line.strip()]
        data_gold = [extract_verilog(x["response"][0]["content"]) for x in gold_records]

        with open(f"results/evolve/sample/{name}.jsonl", "r", encoding="utf-8") as f:
            dut_records = [json.loads(line) for line in f if line.strip()]
        problem_ids = [x["problem_id"] for x in dut_records]
        data_dut = [extract_verilog(x["response"][0]["content"]) for x in dut_records]

        print(len(data_gold), len(data_dut), len(problem_ids))
        # Explicit check (not assert, which is stripped under -O): the samples
        # must be a whole number of responses per reference.
        if not data_gold or len(data_dut) % len(data_gold) != 0:
            raise ValueError(
                f"{name}: {len(data_dut)} samples cannot be evenly assigned "
                f"to {len(data_gold)} references"
            )
        n_sample = len(data_dut) // len(data_gold)

        # Consecutive groups of n_sample samples share one reference.
        testcases = [
            (data_gold[i // n_sample], dut, i) for i, dut in enumerate(data_dut)
        ]
        # For a quick debugging run, slice ``testcases`` here.

        cpu_num = 64  # fixed worker count
        chunksize = 1  # hand out one test case at a time
        results = process_map(
            verify_one_sample_wrapper,
            testcases,
            max_workers=cpu_num,
            chunksize=chunksize,
        )
        for result, problem_id in zip(results, problem_ids):
            result["problem_id"] = problem_id

        with open(f"results/evolve/eval/{name}.jsonl", "w", encoding="utf-8") as f:
            f.write("\n".join(map(json.dumps, results)) + "\n")

        print(f"{name}.jsonl is processed!!!")
