from utils_vllm import vllm_chatcomplete, vllm_score
from utils import read_config, load_jsonl, save_jsonl
from utils_metric import group_results, score_pass_at_k
from utils_dataset import mk_critic_verify, get_score_token_id



def main():
    """Run the critic test pipeline: generate, score, then evaluate pass@k.

    All paths and parameters come from ``read_config()``. The stages are:

    1. ``vllm_chatcomplete`` is run with the critic model on the test
       prompts; its output is expected at ``reason_result_path`` because
       that file is loaded in the next step.
    2. Each loaded item has the ``mk_critic_verify()`` messages appended to
       its ``"messages"`` and the result is saved as the scoring prompt file.
    3. ``vllm_score`` is run on that prompt file with the score token id of
       the critic model; its output is expected at ``score_result_path``.
    4. The score results are grouped and ``score_pass_at_k`` is evaluated
       for k = 1..15; the list is saved to ``eval_result_path`` and printed.
    """
    cfg = read_config()
    critic_cfg = cfg["critic"]
    test_cfg = critic_cfg["test"]
    model_path = critic_cfg["model_path"]

    # Intermediate file handed from the prompt-building step to vllm_score.
    score_prompt_path = "test_score_prompt.jsonl"

    vllm_chatcomplete(
        model_path,
        test_cfg["prompt_path"],
        test_cfg["reason_result_path"],
        test_cfg["sampling_params"],
    )

    # The original referenced an undefined ``args.model`` here, which raised
    # NameError after the generation pass. The score token is looked up for
    # the same critic model that is used for generation and scoring.
    score_token = get_score_token_id(model_path)

    reason_results = load_jsonl(test_cfg["reason_result_path"])
    score_prompts = []
    for item in reason_results:
        # Extends the item's message list in place with the verification
        # messages (presumably a list of chat turns -- confirm in
        # utils_dataset.mk_critic_verify).
        item["messages"] += mk_critic_verify()
        score_prompts.append(item)

    save_jsonl(score_prompts, score_prompt_path)

    vllm_score(
        model_path,
        score_prompt_path,
        test_cfg["score_result_path"],
        score_token,
    )

    results = load_jsonl(test_cfg["score_result_path"])
    groups = group_results(results, cfg["apps"])
    # One evaluation entry per k in 1..15 (upper bound of range is exclusive).
    eval_results = [score_pass_at_k(groups, k, "critic") for k in range(1, 16)]
    save_jsonl(eval_results, test_cfg["eval_result_path"])
    print(eval_results)


if __name__ == "__main__":
    main()
