import argparse
from collections import defaultdict
from functools import partial

from vllm import SamplingParams
from datasets import load_dataset

from codecritic.utils.parallel import model_map
from codecritic.utils.inference import generate_worker
from codecritic.utils.json import load_jsonl, save_jsonl
import codecritic.evaluation.apps_eval as evaluation
import codecritic.dataset.algolr_prompt as promptlib


def index_by_ids(items):
    """Group records into a two-level mapping keyed by task and solution id.

    Args:
        items: Iterable of dicts that each carry "task_id" and "solution_id".

    Returns:
        A ``defaultdict(dict)`` where ``index[task_id][solution_id]`` is the
        record. A later record overwrites an earlier one with the same ids.
    """
    index = defaultdict(dict)
    for item in items:
        index[item["task_id"]][item["solution_id"]] = item
    return index


def select_pair(index, pair):
    """Return the ``(chosen, rejected)`` records referenced by a pair entry.

    Args:
        index: Two-level mapping ``index[task_id][solution_id] -> record``.
        pair: Dict with "task_id", "chosen" and "rejected" keys; the latter
            two hold solution ids of that task.

    Returns:
        Tuple ``(chosen_record, rejected_record)``.

    Raises:
        KeyError: If either solution id is not present for the task.
    """
    solutions = index[pair["task_id"]]
    return solutions[pair["chosen"]], solutions[pair["rejected"]]


if __name__ == "__main__":
    # Pipeline: generate hints for rejected solutions, generate reasoning for
    # both solutions of every pair, verify the reasoning, and write the
    # surviving samples as an SFT dataset to --output.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="path/to/model")
    parser.add_argument("--dataset", type=str, help="path/to/sample")
    parser.add_argument("--pairinfo", type=str, help="path/to/pairinfo")
    parser.add_argument("--apps", type=str, help="path/to/apps")
    parser.add_argument("--output", type=str, help="path/to/score")
    parser.add_argument("--hint_level", type=str, choices=["beginner"])
    parser.add_argument(
        "--tp", type=int, default=1, help="tensor parallel"
    )
    args = parser.parse_args()

    # Step0 load dataset
    dataset = load_jsonl(args.dataset)
    pairinfo = load_jsonl(args.pairinfo)

    # ds[task_id][solution_id] -> sample record
    ds = index_by_ids(dataset)

    # Step1 Generate hints
    hint_prompts = []
    for pair in pairinfo:
        chosen, rejected = select_pair(ds, pair)
        # The option is declared as --hint_level, so argparse stores it
        # under args.hint_level.
        prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.hint_level)
        hint_prompts.append(prompt)

    sampling_params = SamplingParams(
        n=1,
        temperature=0,
        top_p=0.95,
        max_tokens=2048,
    )

    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    hint_responses = model_map(worker, hint_prompts, args.tp)

    hints = [promptlib.postprocess_to_hint(x) for x in hint_responses]
    # hints: {"dataset"..., "task_id": ..., "solution_id": ..., "hints": ...}

    # Keep the intermediate hints next to the final output for inspection.
    save_jsonl(hints, args.output + ".hints")

    # hints_dict[task_id][solution_id] -> hint record
    hints_dict = index_by_ids(hints)

    # Step2 Generate reasoning
    CORRECT_HINT = "The code is correct."
    reason_prompts = []
    for pair in pairinfo:
        chosen, rejected = select_pair(ds, pair)

        # chosen
        chosen_prompt = promptlib.process_to_reason_prompt(chosen, CORRECT_HINT)
        reason_prompts.append(chosen_prompt)

        # rejected
        # NOTE(review): this passes the whole hint record (a dict) while the
        # chosen branch passes a plain string -- confirm that
        # process_to_reason_prompt accepts both, or whether only the hint
        # text should be passed here.
        rejected_hints = hints_dict[pair["task_id"]][pair["rejected"]]
        rejected_prompt = promptlib.process_to_reason_prompt(rejected, rejected_hints)
        reason_prompts.append(rejected_prompt)

    sampling_params = SamplingParams(
        n=4,
        temperature=0.8,
        top_p=0.95,
        max_tokens=4096,
    )

    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    reason_responses = model_map(worker, reason_prompts, args.tp)

    # Step3 Verify reasoning results
    # add prompt "correct the code based on the reasoning"
    # original solution_id is only an int, now change it to x_y(str) x is solution_id, y is the reason id.
    reason_id_counter = defaultdict(lambda: defaultdict(int))

    # Iterate through the list and update solution_id
    for item in reason_responses:
        task_id = item["task_id"]
        solution_id = item["solution_id"]

        # Number the responses of each (task, solution) as 0, 1, 2, ... in
        # the order they appear.
        reason_id = reason_id_counter[task_id][solution_id]
        item["solution_id"] = f"{solution_id}_{reason_id}"
        reason_id_counter[task_id][solution_id] += 1

        promptlib.remove_hint(item)

        item["messages"].append({"role": "user", "content": promptlib.get_debug_prompt()})

    sampling_params = SamplingParams(
        n=1,
        temperature=0,
        top_p=0.95,
        max_tokens=2048,
    )

    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    verify_responses = model_map(worker, reason_responses, args.tp)
    print(f"verify response size: {len(verify_responses)}")

    # postprocess verify_response.
    # filter the judgement that are not consistent with ground truth.
    verify_passed = []
    for item in verify_responses:
        conclusion, code = promptlib.extract_conclusion_and_code(item["messages"][-1]["content"])
        if conclusion == item["pass"]:
            item["code"] = code
            verify_passed.append(item)

    print(f"verify passed (judgement consistent) size: {len(verify_passed)}")

    # Split by ground-truth label: only samples labelled as failing carry
    # code that has to be re-evaluated.
    incorrects = [item for item in verify_passed if not item["pass"]]
    corrects = [item for item in verify_passed if item["pass"]]

    # need a list of dict {"task_id": str, "solution_id": str(unique index), "code": ...}
    apps = load_dataset(args.apps)
    fixed_incorrects = evaluation.evaluate(incorrects, apps)

    # filter that code is not correct.
    verify_passed = [x for x in fixed_incorrects if x["pass"]] + corrects
    # Labelled differently from the count above: this one is taken after
    # dropping samples whose extracted code did not pass evaluation.
    print(f"verify passed (fixed code correct) size: {len(verify_passed)}")

    # Step4 Remove hints and Reformat to a SFT dataset
    # extract reasoning sets

    sft = []
    for item in verify_passed:
        line = {
            "dataset": item["dataset"],
            "task_id": item["task_id"],
            "solution_id": item["solution_id"],
            # First message is kept as the question, second as the response;
            # any later messages are dropped.
            "question": item["messages"][:1],
            "response": item["messages"][1:2],
        }
        sft.append(line)

    print(f"Size of sft dataset: {len(sft)}")
    save_jsonl(sft, args.output)
