import argparse
import json
import requests
from tqdm import tqdm
from transformers import AutoTokenizer

from codecritic.utils.json import load_jsonl, save_jsonl


def get_rewards_from_server(server_url: str, messages: list[str], timeout=None):
    """
    Gets reward scores from the API server.

    POSTs ``{"query": messages}`` as a JSON body to ``server_url`` and returns
    the ``"rewards"`` field of the JSON response.

    Args:
        server_url: Full URL of the reward endpoint.
        messages: Query strings to score; sent as the ``"query"`` list.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) waits
            indefinitely, matching the previous behaviour; pass a number of
            seconds (or a ``(connect, read)`` tuple) to bound the wait.

    Returns:
        The ``"rewards"`` value from the response body. Presumably one score
        per entry of ``messages`` (server contract; not checked here).

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        KeyError: if the response JSON has no ``"rewards"`` key.
    """
    # `json=` serializes the payload and sets Content-Type: application/json.
    response = requests.post(server_url, json={"query": messages}, timeout=timeout)
    # Surface HTTP failures directly instead of trying to parse an error page as JSON.
    response.raise_for_status()
    return response.json()["rewards"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Score chat samples with a reward-model HTTP server."
    )
    parser.add_argument("--model", type=str, required=True, help="path/to/model")
    parser.add_argument("--sample", type=str, required=True, help="path/to/sample")
    parser.add_argument("--output", type=str, required=True, help="path/to/score")
    parser.add_argument(
        "--server_url",
        type=str,
        # NOTE(review): 0.0.0.0 as a *destination* works on Linux but is not
        # portable; 127.0.0.1 is the usual loopback address. Kept for compatibility.
        default="http://0.0.0.0:5000/get_reward",
        help="reward endpoint URL",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="number of samples sent per request (default: 1)",
    )
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    # compute score
    # Materialize so the dataset can be sliced into batches and saved afterwards;
    # each item is assumed to be a dict with a "messages" chat list.
    dataset = list(load_jsonl(args.sample))
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    with tqdm(total=len(dataset)) as progress:
        for start in range(0, len(dataset), args.batch_size):
            batch = dataset[start:start + args.batch_size]
            # Render each conversation to the model's chat-format string.
            queries = [
                tokenizer.apply_chat_template(item["messages"], tokenize=False)
                for item in batch
            ]
            scores = get_rewards_from_server(args.server_url, queries)
            # Guard against silently misaligning scores with samples.
            if len(scores) != len(batch):
                raise RuntimeError(
                    f"server returned {len(scores)} rewards for {len(batch)} queries"
                )
            for item, score in zip(batch, scores):
                item["score"] = score
            progress.update(len(batch))

    save_jsonl(dataset, args.output)
