import argparse
from collections import defaultdict
import json
import requests
from tqdm import tqdm
from transformers import AutoTokenizer

from codecritic.utils.json import load_jsonl, save_jsonl


def get_rewards_from_server(
    server_url: str,
    messages: list[str],
    timeout: float = 300.0,
):
    """Request reward scores for a batch of queries from the reward API server.

    Args:
        server_url: Endpoint that accepts a JSON POST body ``{"query": [...]}``.
        messages: Query strings to score, sent as the ``"query"`` field.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout a stalled server would block the
            caller indefinitely.

    Returns:
        The ``"rewards"`` value from the server's JSON response. Presumably a
        list with one score per query, in order — confirm against the server.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
        KeyError: If the response JSON has no ``"rewards"`` field.
    """
    headers = {"Content-Type": "application/json"}
    payload = {"query": messages}
    response = requests.post(
        server_url, json=payload, headers=headers, timeout=timeout
    )
    # Surface HTTP failures directly instead of as a confusing JSON/KeyError.
    response.raise_for_status()
    return response.json()["rewards"]


def mark_unsolved_tasks(dataset):
    """Assign ``positive_score = 0`` to every sample of a task with no passing sample.

    Samples are grouped by ``task_id``. When none of a task's samples has a
    truthy ``"pass"`` field, each of them gets ``positive_score = 0`` in place.
    The scoring loop in ``main`` then skips querying the reward server for
    them. Tasks with at least one passing sample are left untouched.

    Args:
        dataset: List of sample dicts, each with ``"task_id"`` and ``"pass"``
            keys. The dicts are mutated in place; nothing is returned.
    """
    by_task = defaultdict(list)
    for item in dataset:
        by_task[item["task_id"]].append(item)

    for items in by_task.values():
        # Equivalent to all(not pass), but short-circuits without a temp list.
        if not any(x["pass"] for x in items):
            for item in items:
                item["positive_score"] = 0


def main():
    """Score each sample of ``--testset`` via the reward server and save to ``--output``.

    Samples of tasks with no passing sample get score 0 without a server call
    (see ``mark_unsolved_tasks``). Samples whose dict already has
    ``positive_score`` are not re-scored. Every other sample's ``"messages"``
    is rendered with the ``--model`` tokenizer's chat template and sent to the
    server one query at a time.
    """
    parser = argparse.ArgumentParser()
    # Mark the paths required so a missing flag fails at parse time,
    # not later with a confusing error on a None path.
    parser.add_argument("--model", type=str, required=True, help="path/to/model")
    parser.add_argument("--testset", type=str, required=True, help="path/to/testset")
    parser.add_argument("--output", type=str, required=True, help="path/to/score")
    parser.add_argument(
        "--server_url",
        type=str,
        default="http://0.0.0.0:5000/get_reward",
        help="reward server endpoint",
    )
    args = parser.parse_args()

    # compute score
    dataset = load_jsonl(args.testset)
    mark_unsolved_tasks(dataset)

    # The model path is only used to obtain the tokenizer's chat template.
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    for item in tqdm(dataset):
        if "positive_score" in item:
            continue
        query = tokenizer.apply_chat_template(item["messages"], tokenize=False)
        item["positive_score"] = get_rewards_from_server(args.server_url, [query])[0]

    save_jsonl(dataset, args.output)


if __name__ == "__main__":
    main()
