import argparse
from pathlib import Path
import pprint

from codecritic.data.utils import mk_message
from codecritic.data.verify import JUDGE_PROMPT
from transformers import AutoTokenizer

from codecritic.data.code import extract_code, code_template
from codecritic.data.cov import COV_PROMPT
from codecritic.data.verify import get_score_token_id
from codecritic.utils.vllm import vllm_chatcomplete, vllm_score
from codecritic.utils.json import load_jsonl, save_jsonl
from codecritic.utils.metric import group_results, score_pass_at_k


def preprocess_test_item(item):
    """Rewrite an item's chat so the answer turn holds only its extracted code.

    The first message is treated as the question and the second as the answer.
    Code is pulled out of the answer with ``extract_code``, wrapped with
    ``code_template``, and ``item["messages"]`` is replaced by
    ``mk_message(question, <wrapped code>)``.

    The item is modified in place and also returned.
    """
    msgs = item["messages"]
    question, answer = msgs[0]["content"], msgs[1]["content"]
    wrapped_code = code_template.format(extract_code(answer))
    item["messages"] = mk_message(question, wrapped_code)
    return item


def append_prompt(item, content):
    """Add a user turn carrying *content* to the end of ``item["messages"]``.

    Mutates *item* in place and returns the same object.
    """
    user_turn = {"role": "user", "content": content}
    item["messages"].append(user_turn)
    return item


def run_sft_model(model_path, test_path, apps_path, reason_prompt=None, max_k=15):
    """Score a test set with an SFT critic model and report pass@k.

    Pipeline:
      1. Load ``test_path`` (JSONL) and normalize every item with
         ``preprocess_test_item``.
      2. If ``reason_prompt`` is truthy, append it as a user turn and let the
         model generate a reply with ``vllm_chatcomplete`` (one greedy sample,
         at most 2048 new tokens).
      3. Append ``JUDGE_PROMPT`` as a user turn and score each item with
         ``vllm_score`` using the tokenizer's score token id. The scores are
         saved to ``<eval>/scores.jsonl``.
      4. Group the scores with ``group_results`` and compute pass@k for
         k = 1..max_k. The results are saved to ``<eval>/passk.jsonl`` and
         printed.

    ``<eval>`` is the ``eval`` directory next to the model, i.e.
    ``Path(model_path).parent / "eval"``. It is created if missing.

    Args:
        model_path: Model location, used by vLLM and ``AutoTokenizer``.
        test_path: JSONL file of test items. Each item must have a
            ``"messages"`` list whose first two entries are question/answer.
        apps_path: Forwarded to ``group_results``. Presumably the APPS
            dataset location; confirm against ``group_results``.
        reason_prompt: Optional prompt text that triggers a reasoning turn
            before judging. A falsy value skips the reasoning step.
        max_k: Largest k for which pass@k is computed. The default of 15
            matches the previously hard-coded range.
    """
    home_path = Path(model_path).parent
    result_dir = home_path / "eval"
    result_dir.mkdir(parents=True, exist_ok=True)

    # preprocess prompt
    raw_test_dataset = load_jsonl(test_path)
    test_dataset = [preprocess_test_item(item) for item in raw_test_dataset]

    # reason: use the caller-supplied prompt. Previously COV_PROMPT was
    # hard-coded here, which silently ignored `reason_prompt`.
    if reason_prompt:
        test_dataset = [append_prompt(x, reason_prompt) for x in test_dataset]
        sampling_params = dict(n=1, temperature=0.0, max_tokens=2048)
        # NOTE(review): assumes vllm_chatcomplete returns the items with the
        # generated assistant turn appended to "messages" — confirm.
        test_dataset = vllm_chatcomplete(model_path, test_dataset, sampling_params)

    # score
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    score_token = get_score_token_id(tokenizer)
    test_dataset = [append_prompt(x, JUDGE_PROMPT) for x in test_dataset]
    results = vllm_score(model_path, test_dataset, score_token)
    score_path = result_dir / "scores.jsonl"
    save_jsonl(results, score_path)

    # compute pass@k
    eval_result_path = result_dir / "passk.jsonl"
    groups = group_results(results, apps_path)
    # home_path.stem (the model's parent directory name) is passed through
    # to score_pass_at_k, presumably as a run label — confirm.
    eval_results = [
        score_pass_at_k(groups, k, home_path.stem) for k in range(1, max_k + 1)
    ]
    save_jsonl(eval_results, eval_result_path)
    pprint.pp(eval_results)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Score a test set with an SFT critic model and report pass@k."
    )
    # These paths have no usable default. Without required=True a missing
    # value only surfaces later as a TypeError from Path(None), so fail fast.
    parser.add_argument("--model", type=str, required=True, help="path to the model")
    parser.add_argument("--test", type=str, required=True, help="JSONL test set")
    parser.add_argument(
        "--apps", type=str, required=True, help="path forwarded to group_results"
    )
    parser.add_argument(
        "--reason", choices=["cov"], help="optional reasoning prompt to run before judging"
    )
    args = parser.parse_args()

    # Map the --reason choice to its prompt text; None disables the reasoning step.
    reason_prompts = {"cov": COV_PROMPT}
    reason_prompt = reason_prompts.get(args.reason)
    run_sft_model(args.model, args.test, args.apps, reason_prompt)
