import argparse
import json
from tqdm import tqdm
import requests
from transformers import AutoTokenizer
import pprint
from pathlib import Path

from codecritic.data.code import code_template, extract_code
from codecritic.utils.json import load_jsonl, save_jsonl
from codecritic.utils.metric import group_results, score_pass_at_k


def get_rewards_from_server(server_url: str, messages: list[str], timeout=None):
    """
    Gets reward scores from the API server.

    POSTs ``{"query": messages}`` as JSON to ``server_url`` and returns the
    value stored under the ``"rewards"`` key of the JSON reply.

    Args:
        server_url: Full URL of the reward endpoint.
        messages: Pre-formatted query strings to be scored.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` keeps the original
            behaviour of waiting indefinitely.

    Returns:
        Whatever the server returned under ``"rewards"``. Presumably this is
        one score per query, in order; verify against the server
        implementation.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        KeyError: If the JSON reply has no ``"rewards"`` field.
    """
    headers = {"Content-Type": "application/json"}
    payload = {"query": messages}
    response = requests.post(server_url, json=payload, headers=headers, timeout=timeout)
    # Fail loudly on HTTP errors instead of trying to parse an error page as JSON.
    response.raise_for_status()
    rewards = response.json()["rewards"]
    return rewards


def test_reward_model(server_url, item, tokenizer):
    """
    Score a single test item with the reward-model server.

    The last message's content is normalized: the code is extracted with
    ``extract_code`` and re-wrapped with ``code_template``. The whole
    conversation is then rendered with the tokenizer's chat template and sent
    to the server as one query.

    The input ``item`` is NOT modified. A new messages list is built so the
    caller's dataset keeps its original content.

    Args:
        server_url: Reward endpoint passed to ``get_rewards_from_server``.
        item: Dict with at least ``"messages"`` (non-empty list of message
            dicts with a ``"content"`` key), ``"problem_id"`` and
            ``"eval_result"``.
        tokenizer: Object providing ``apply_chat_template`` (e.g. a
            Hugging Face tokenizer).

    Returns:
        Dict with ``problem_id``, the normalized ``messages``,
        ``eval_result`` and the server's ``score`` for this item.
    """
    original_messages = item["messages"]
    response = original_messages[-1]["content"]
    code = code_template.format(extract_code(response))

    # Copy only the last message (the one we rewrite). Earlier messages are
    # shared, but they are never modified.
    messages = [*original_messages[:-1], {**original_messages[-1], "content": code}]

    query = tokenizer.apply_chat_template(messages, tokenize=False)
    score = get_rewards_from_server(server_url, [query])[0]

    return {
        "problem_id": item["problem_id"],
        "messages": messages,
        "eval_result": item["eval_result"],
        "score": score,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Score test responses with a reward-model server and compute pass@k."
    )
    parser.add_argument("--model", type=str, required=True,
                        help="Reward model path; used to load the tokenizer and to locate the output directory.")
    parser.add_argument("--test", type=str, required=True,
                        help="JSONL file of test items to score.")
    parser.add_argument("--apps", type=str,
                        help="Value forwarded to group_results (presumably the APPS dataset location).")
    parser.add_argument("--server_url", type=str, default="http://0.0.0.0:5000/get_reward",
                        help="Reward server endpoint.")
    parser.add_argument("--max_k", type=int, default=15,
                        help="Compute pass@k for k = 1 .. max_k.")
    args = parser.parse_args()

    # Results are written next to the model, under <model parent>/eval.
    home_path = Path(args.model).parent
    result_dir = home_path / "eval"
    result_dir.mkdir(parents=True, exist_ok=True)

    # compute score
    test_dataset = load_jsonl(args.test)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    results = [test_reward_model(args.server_url, item, tokenizer) for item in tqdm(test_dataset)]
    score_path = result_dir / "scores.jsonl"
    save_jsonl(results, score_path)

    # compute pass@k
    # Re-read from disk so pass@k is computed from exactly what was persisted.
    results = load_jsonl(score_path)
    groups = group_results(results, args.apps)
    # home_path.stem is passed as the third argument; presumably a run label. Verify against score_pass_at_k.
    eval_results = [score_pass_at_k(groups, k, home_path.stem) for k in range(1, args.max_k + 1)]
    eval_result_path = result_dir / "passk.jsonl"
    save_jsonl(eval_results, eval_result_path)
    pprint.pp(eval_results)
