from utils_vllm import vllm_score
from utils import read_config, load_jsonl, save_jsonl, extract_code
from utils_preference_dataset import code_template, mk_critic_qa, mk_critic_verify
from utils_metric import group_results, score_pass_at_k
from transformers import AutoTokenizer


def preprocess_test_item(item):
    """Replace ``item["messages"]`` with a critic prompt and return the item.

    The content of the first message is used as the question and the content
    of the second message as the answer. The code extracted from the answer
    is inserted into ``code_template``, and the new ``messages`` value is
    ``mk_critic_qa(question, code) + mk_critic_verify()``.

    The item is modified in place; the returned object is the same one that
    was passed in.
    """
    messages = item["messages"]
    question = messages[0]["content"]
    answer = messages[1]["content"]

    # Wrap only the extracted code (not the whole answer) in the template.
    extracted = extract_code(answer)
    code = code_template.format(extracted)

    # Question/code turns first, followed by the verification turn(s).
    critic_prompt = mk_critic_qa(question, code)
    critic_prompt = critic_prompt + mk_critic_verify()

    item["messages"] = critic_prompt
    return item


if __name__ == "__main__":
    # Script flow: build critic prompts from the minimal test set, score them
    # with the model at cfg["sftorm"]["model_path"], then compute and save
    # pass@k results for k = 1..15.
    cfg = read_config()

    # Read every required config entry up front so that a missing key fails
    # immediately instead of after the (expensive) scoring step.
    raw_test_path = cfg["dataset"]["minimal_test_path"]
    model_path = cfg["sftorm"]["model_path"]
    test_cfg = cfg["sftorm"]["test"]
    prompt_path = test_cfg["prompt_path"]
    score_result_path = test_cfg["score_result_path"]
    eval_result_path = test_cfg["eval_result_path"]
    apps_cfg = cfg["apps"]

    # Convert each raw item into a critic prompt and persist the prompts.
    raw_test_dataset = load_jsonl(raw_test_path)
    test_dataset = [preprocess_test_item(item) for item in raw_test_dataset]
    save_jsonl(test_dataset, prompt_path)

    # The score is keyed on a single token id for "Yes". Raise explicitly
    # (rather than assert, which is stripped under ``python -O``) when the
    # tokenizer splits "Yes" into several tokens or none.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    score_tokens = tokenizer.encode("Yes", add_special_tokens=False)
    if len(score_tokens) != 1:
        raise ValueError(
            f'"Yes" must encode to exactly one token, got {score_tokens!r}'
        )
    score_token = score_tokens[0]

    # NOTE(review): vllm_score presumably writes its output to
    # score_result_path, since that file is loaded right below -- confirm.
    vllm_score(
        model_path,
        prompt_path,
        score_result_path,
        score_token,
    )

    # Group the scored results and evaluate pass@k for k = 1..15.
    results = load_jsonl(score_result_path)
    groups = group_results(results, apps_cfg)
    eval_results = [score_pass_at_k(groups, k, "sft-orm") for k in range(1, 16)]
    save_jsonl(eval_results, eval_result_path)
    print(eval_results)
