import argparse
from pathlib import Path
import pprint
from utils_vllm import vllm_score
from utils import load_jsonl, save_jsonl, extract_code, code_template
from utils_dataset import mk_critic_qa, mk_critic_verify, get_score_token_id
from utils_metric import group_results, score_pass_at_k


def preprocess_test_item(item):
    """Rewrite a test item's conversation into a critic-verification prompt.

    The first message's content is taken as the question and the second
    message's content as the answer.  The code extracted from the answer is
    wrapped with ``code_template`` and the item's ``"messages"`` entry is
    replaced by ``mk_critic_qa(question, code) + mk_critic_verify()``.

    The item is modified in place and the same object is returned.
    """
    turns = item["messages"]
    prompt_text = turns[0]["content"]
    reply_text = turns[1]["content"]

    # Pull the code out of the answer and wrap it in the shared template.
    wrapped_code = code_template.format(extract_code(reply_text))

    critic_dialogue = mk_critic_qa(prompt_text, wrapped_code)
    critic_dialogue = critic_dialogue + mk_critic_verify()

    item["messages"] = critic_dialogue
    return item


if __name__ == "__main__":
    # Pipeline: build critic prompts from the test set, score them with the
    # model, then aggregate the scores into pass@k results for k = 1..15.
    parser = argparse.ArgumentParser()
    # --model and --test are dereferenced unconditionally below, so let
    # argparse reject a missing value with a usage message instead of failing
    # later with an opaque error on None.
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="model path; results are written to an 'eval' directory next to it",
    )
    parser.add_argument(
        "--test",
        type=str,
        required=True,
        help="path to the JSONL test dataset",
    )
    # NOTE(review): left optional because it is only forwarded to
    # group_results; confirm whether that helper accepts None.
    parser.add_argument(
        "--apps",
        type=str,
        help="value passed to group_results as its second argument",
    )
    args = parser.parse_args()

    # Outputs live in <parent of model path>/eval.
    home_path = Path(args.model).parent
    result_dir = home_path / "eval"
    # parents=True so a not-yet-existing parent directory does not abort the run.
    result_dir.mkdir(parents=True, exist_ok=True)

    # Preprocess: turn every test item into a critic prompt and persist them.
    prompt_path = result_dir / "prompt.jsonl"
    raw_test_dataset = load_jsonl(args.test)
    test_dataset = [preprocess_test_item(item) for item in raw_test_dataset]
    save_jsonl(test_dataset, prompt_path)

    # Score: presumably vllm_score reads prompt_path and writes score_path
    # (the latter is loaded right below) -- TODO confirm against utils_vllm.
    score_path = result_dir / "scores.jsonl"
    score_token = get_score_token_id(args.model)
    vllm_score(args.model, prompt_path, score_path, score_token)

    # Compute pass@k for k = 1..15 over the grouped scoring results.
    eval_result_path = result_dir / "passk.jsonl"
    results = load_jsonl(score_path)
    groups = group_results(results, args.apps)
    eval_results = [score_pass_at_k(groups, k, "sft-orm") for k in range(1, 16)]
    save_jsonl(eval_results, eval_result_path)
    pprint.pp(eval_results)
