import argparse
from functools import partial
import os

from transformers import AutoTokenizer
from vllm import SamplingParams

from codecritic.dataset.genrm_prompt import THINK_MESSAGE, JUDGE_MESSAGE, JUDGE_TOEKNS
from codecritic.utils.inference import generate_worker, score_worker
from codecritic.utils.parallel import model_map
from codecritic.utils.json import load_jsonl, save_jsonl


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the scoring script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        Namespace carrying ``model``, ``sample``, ``output``, ``reasoning``,
        ``reason_max_tokens`` and ``tp``.
    """
    parser = argparse.ArgumentParser()
    # The three paths are mandatory: without ``required=True`` a missing flag
    # would flow on as ``None`` and only fail deep inside the tokenizer or the
    # JSONL helpers with an unrelated-looking error.
    parser.add_argument("--model", type=str, required=True, help="path/to/model")
    parser.add_argument("--sample", type=str, required=True, help="path/to/sample")
    parser.add_argument("--output", type=str, required=True, help="path/to/score")
    parser.add_argument("--reasoning", action="store_true", help="enable reasoning")
    parser.add_argument(
        "--reason_max_tokens",
        type=int,
        default=4096,
        help="maximum number of tokens allowed for the reasoning process.",
    )
    parser.add_argument(
        "--tp", type=int, default=1, help="tensor parallel"
    )
    return parser.parse_args(argv)


def get_token_id(tokenizer, token) -> int:
    """Return the single token id that ``token`` encodes to.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``
            and returning a sequence of token ids.
        token: Text of the judge token to look up.

    Returns:
        The one id produced by encoding ``token`` without special tokens.

    Raises:
        ValueError: If ``token`` does not encode to exactly one id. This is an
            explicit check rather than an ``assert`` so it still runs under
            ``python -O``.
    """
    token_ids = tokenizer.encode(token, add_special_tokens=False)
    if len(token_ids) != 1:
        raise ValueError(
            f"judge token {token!r} must encode to exactly one token id, "
            f"got {token_ids!r}"
        )
    return token_ids[0]


def main() -> None:
    """Score the samples in ``--sample`` and write the result to ``--output``.

    Steps:
      1. Load the tokenizer for ``--model`` and resolve the positive and
         negative judge tokens to ids.
      2. With ``--reasoning``: append ``THINK_MESSAGE`` to every item's
         ``messages`` and run ``generate_worker`` over the dataset.
      3. Append ``JUDGE_MESSAGE`` to every item's ``messages`` and run
         ``score_worker`` over the dataset.
      4. Save the resulting dataset as JSONL.
    """
    args = parse_args()

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Resolve the judge tokens up front: the lookup needs only the tokenizer,
    # so a mismatch is reported before the (potentially long) generation pass.
    # NOTE: ``JUDGE_TOEKNS`` is the spelling exported by genrm_prompt.
    positive_token = get_token_id(tokenizer, JUDGE_TOEKNS["positive"])
    negative_token = get_token_id(tokenizer, JUDGE_TOEKNS["negative"])

    dataset = load_jsonl(args.sample)

    if args.reasoning:
        for item in dataset:
            item["messages"].append(THINK_MESSAGE)

        sampling_params = SamplingParams(
            n=1,
            temperature=0,
            top_p=0.95,
            max_tokens=args.reason_max_tokens,
        )

        reasoning_worker = partial(
            generate_worker, model_path=args.model, sampling_params=sampling_params
        )
        # ``args.tp`` is forwarded as model_map's third argument; presumably
        # the tensor-parallel size per worker -- confirm against model_map.
        dataset = model_map(reasoning_worker, dataset, args.tp)

    for item in dataset:
        item["messages"].append(JUDGE_MESSAGE)

    scoring_worker = partial(
        score_worker,
        model_path=args.model,
        positive_token=positive_token,
        negative_token=negative_token,
    )
    dataset = model_map(scoring_worker, dataset, args.tp)

    save_jsonl(dataset, args.output)


if __name__ == "__main__":
    main()
