import argparse
import os
import pprint
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from openai import OpenAI

from codecritic.utils.json import load_jsonl, save_jsonl
import codecritic.dataset.distill_prompt as promptlib


# Shared chat-completions client for the DeepSeek endpoint.
# The API key is read from the environment instead of being hard-coded: a key
# committed to source control is effectively public (revoke/rotate the key that
# used to be embedded here). If DEEPSEEK_API_KEY is unset, ``api_key`` is None
# and the OpenAI client falls back to its own default key lookup.
client = OpenAI(
    api_key=os.environ.get("DEEPSEEK_API_KEY"),
    base_url="https://api.deepseek.com",
)


def generate_completion(prompt):
    """Ask the chat model to answer ``prompt["gen_prompt"]`` and attach the reply.

    Args:
        prompt: A dict whose ``"gen_prompt"`` entry is the list of chat
            messages sent to the model.

    Returns:
        The same ``prompt`` dict, mutated in place so that ``"cot"`` holds the
        text of the first returned choice. Returning it lets this function be
        used directly with ``executor.map``.
    """
    request_kwargs = dict(
        model="deepseek-chat",
        messages=prompt["gen_prompt"],
        max_tokens=1024,
        temperature=0.7,
        stream=False,
    )
    completion = client.chat.completions.create(**request_kwargs)

    first_choice = completion.choices[0]
    prompt["cot"] = first_choice.message.content
    return prompt


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # All three paths are mandatory; without them the script would fail later
    # with an obscure error (load_jsonl(None) or ``None + ".raw_response"``).
    parser.add_argument("--dataset", type=str, required=True, help="path/to/sample")
    parser.add_argument("--pairinfo", type=str, required=True, help="path/to/pairinfo")
    parser.add_argument("--output", type=str, required=True, help="path/to/score")
    args = parser.parse_args()

    dataset = load_jsonl(args.dataset)
    pairinfo = load_jsonl(args.pairinfo)

    # Index samples as ds[task_id][solution_id] -> row.
    ds = defaultdict(dict)
    for item in dataset:
        ds[item["task_id"]][item["solution_id"]] = item

    # Build the distillation prompts for every (chosen, rejected) pair.
    verify_prompts = []
    for pair in pairinfo:
        task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
        chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]
        verify_prompts.extend(promptlib.mk_distillation_messages(chosen, rejected))

    pprint.pp(verify_prompts[:2])

    # Resume support: the raw responses already on disk correspond, in order,
    # to the first len(generated_responses) prompts, so only the rest are sent.
    # NOTE(review): this assumes the dataset/pairinfo inputs (and therefore the
    # prompt order) are unchanged between runs — confirm.
    raw_response_path = args.output + ".raw_response"
    if Path(raw_response_path).exists():
        generated_responses = load_jsonl(raw_response_path)
        verify_prompts = verify_prompts[len(generated_responses):]
    else:
        generated_responses = []

    print("generated:", len(generated_responses), "rest", len(verify_prompts))

    # Checkpoint after every chunk. An exception from a single API call
    # (re-raised by executor.map) then loses only the in-flight chunk instead
    # of every completion from this run, and a rerun resumes via the logic
    # above. executor.map yields results in input order, which keeps the
    # on-disk file aligned with verify_prompts.
    max_workers = 8
    checkpoint_every = 16 * max_workers
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for start in range(0, len(verify_prompts), checkpoint_every):
            chunk = verify_prompts[start:start + checkpoint_every]
            generated_responses.extend(executor.map(generate_completion, chunk))
            save_jsonl(generated_responses, raw_response_path)
    if not verify_prompts:
        # Nothing new was generated; still (re)write the file as before.
        save_jsonl(generated_responses, raw_response_path)

    outputs = []
    for res in generated_responses:
        is_valid, clean_response, verification = promptlib.postprocess_result(res["cot"])
        print(is_valid, verification)
        # Keep only well-formed responses whose verdict matches the sample's
        # recorded "pass" field.
        if is_valid and verification == res["pass"]:
            outputs.append({
                "task_id": res["task_id"],
                "solution_id": res["solution_id"],
                "question": res["train_prompt"],
                # NOTE(review): the model-generated reply is tagged with role
                # "user"; "assistant" looks intended — confirm against the
                # consumer of this file before changing.
                "response": [{"role": "user", "content": clean_response}],
            })

    save_jsonl(outputs, args.output)
