import argparse
from copy import deepcopy
import json
from tqdm import tqdm
import requests
from transformers import AutoTokenizer
import pprint
from pathlib import Path

from utils import load_jsonl, save_jsonl, extract_code, code_template
from utils_metric import group_results, score_pass_at_k


def get_rewards_from_server(server_url: str, messages: list[str]):
    """Request reward scores for ``messages`` from the scoring server.

    Sends a POST request whose JSON body is
    ``{"model": "model", "messages": messages}`` and returns the ``"scores"``
    field of the JSON response.

    Args:
        server_url: Full URL of the scoring endpoint.
        messages: Formatted prompt strings to be scored.

    Returns:
        The value stored under ``"scores"`` in the response body. The caller
        in this file indexes it with ``[0]``; presumably it holds one score
        per message (confirm against the server implementation).

    Raises:
        requests.HTTPError: If the server replies with a 4xx/5xx status.
        json.JSONDecodeError: If the response body is not valid JSON.
        KeyError: If the response JSON has no ``"scores"`` field.
    """
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers)
    # Surface HTTP failures directly; otherwise an error page would show up
    # below as a confusing JSON decode error or a missing "scores" key.
    response.raise_for_status()
    return json.loads(response.text)["scores"]


def preprocess_dataset(model_path, test_dataset, gpu_num, base_port=8000):
    """Apply the chat template to every item and assign it a server port.

    For each item, the content of its last message is replaced (on a copy)
    by ``code_template`` filled with the code extracted from that message,
    and the resulting conversation is rendered to a string with the
    tokenizer's chat template.

    Note:
        Each ``item`` is modified in place: a ``"format_str"`` key holding a
        one-element list with the rendered prompt is added to it. The
        original ``item["messages"]`` is left untouched.

    Args:
        model_path: Path or identifier passed to
            ``AutoTokenizer.from_pretrained``.
        test_dataset: Iterable of dict items, each with a non-empty
            ``"messages"`` list whose entries have a ``"content"`` field.
        gpu_num: Number of ports to spread items over in round-robin order.
            Must be non-zero when the dataset is non-empty, otherwise the
            modulo raises ``ZeroDivisionError``.
        base_port: First port of the range; item ``i`` is assigned
            ``base_port + i % gpu_num``. Defaults to 8000.

    Returns:
        A list of ``(item, port)`` tuples, in dataset order.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    result = []
    for i, item in enumerate(test_dataset):
        # Deep-copy so rewriting the last message does not alter the dataset.
        messages = deepcopy(item["messages"])
        messages[-1]["content"] = code_template.format(
            extract_code(messages[-1]["content"])
        )
        # LLaMA-Factory's template should match the tokenizer's
        # `apply_chat_template`; see
        # https://github.com/hiyouga/LLaMA-Factory/blob/a45f3f5461e2936b9e119eda2ef4d8c7a4131740/tests/data/test_template.py#L58
        item["format_str"] = [tokenizer.apply_chat_template(messages, tokenize=False)]
        result.append((item, base_port + i % gpu_num))
    return result


def test_reward_model(item, api_port):
    """Score one preprocessed item against the server listening on ``api_port``.

    Sends ``item["format_str"]`` to the ``/v1/score/evaluation`` endpoint and
    keeps the first returned score.

    Args:
        item: Dict providing ``"format_str"``, ``"problem_id"``,
            ``"messages"`` and ``"eval_result"``.
        api_port: Port of the scoring server.

    Returns:
        A new dict with the item's ``"problem_id"``, ``"messages"`` and
        ``"eval_result"`` plus the obtained ``"score"``.
    """
    endpoint = f"http://0.0.0.0:{api_port}/v1/score/evaluation"
    # Only one prompt is sent per item, so only the first score is used.
    first_score = get_rewards_from_server(endpoint, item["format_str"])[0]

    carried_fields = ("problem_id", "messages", "eval_result")
    record = {field: item[field] for field in carried_fields}
    record["score"] = first_score
    return record


if __name__ == "__main__":
    # Command-line entry point: score the test set through the reward-model
    # server, save the scores, then compute and save pass@k for k = 1..15.
    parser = argparse.ArgumentParser(
        description="Score a test set with a reward-model server and compute pass@k."
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model path used to load the tokenizer; results are written to "
        "an 'eval' directory next to it.",
    )
    parser.add_argument(
        "--test",
        type=str,
        required=True,
        help="Path of the JSONL test dataset to score.",
    )
    parser.add_argument(
        "--apps",
        type=str,
        help="Value forwarded to group_results when grouping the scored results.",
    )
    args = parser.parse_args()

    # Outputs live in "<parent of model path>/eval".
    home_path = Path(args.model).parent
    result_dir = home_path / "eval"
    result_dir.mkdir(parents=True, exist_ok=True)

    # compute score
    score_path = result_dir / "scores.jsonl"
    raw_test_dataset = load_jsonl(args.test)
    # A single server is assumed here, so every item gets the same port.
    test_dataset = preprocess_dataset(args.model, raw_test_dataset, 1)
    results = [test_reward_model(*arg) for arg in tqdm(test_dataset)]
    save_jsonl(results, score_path)

    # compute pass@k
    eval_result_path = result_dir / "passk.jsonl"
    # NOTE: the scores were just saved to `score_path`; to redo only the
    # pass@k step, they can presumably be reloaded with load_jsonl(score_path)
    # instead of re-querying the server.
    groups = group_results(results, args.apps)
    eval_results = [score_pass_at_k(groups, k, home_path.stem) for k in range(1, 16)]
    save_jsonl(eval_results, eval_result_path)
    pprint.pp(eval_results)
