import argparse
import math
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from openai import OpenAI

from codecritic.utils.json import load_jsonl, save_jsonl


# Shared OpenAI-compatible client used by every request in this script. It
# points at a server on localhost:8000 (presumably a locally hosted inference
# server exposing the OpenAI API -- confirm). The api_key is a fixed
# placeholder string handed to the client, not a real secret.
client = OpenAI(
    api_key="token-abc123",
    base_url="http://localhost:8000/v1",
)


def chat(messages, model, *, max_tokens=8192, temperature=0):
    """Send one chat-completion request via the module-level ``client``; return the reply text.

    Args:
        messages: Chat messages in OpenAI format (list of ``{"role", "content"}`` dicts).
        model: Model name as exposed by the server behind ``client``.
        max_tokens: Upper bound on generated tokens (defaults to the previously
            hard-coded 8192).
        temperature: Sampling temperature; the default 0 keeps the previous
            (near-)deterministic decoding.

    Returns:
        The content of the first choice's message. The OpenAI SDK types this as
        optional, so it may be ``None`` if the server returns no text content.

    Raises:
        Whatever the OpenAI client raises on a failed request; nothing is caught here.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return completion.choices[0].message.content


def load_dataset(test_path):
    """Load a JSONL test set and split it by whether each task has any passing item.

    Items are grouped by ``task_id``. A task is "unsolvable" when none of its
    items has a truthy ``"pass"`` field. Those items are scored in place with
    ``confidence = 0`` and ``prediction = None`` (the script's main block only
    sends ``dataset`` to the model).

    Args:
        test_path: Path forwarded to ``load_jsonl``.

    Returns:
        ``(dataset, unsolvable)``: two lists of item dicts. Tasks appear in
        first-seen order and each task's items stay together in file order.
    """
    task_dict = defaultdict(list)
    for item in load_jsonl(test_path):
        task_dict[item["task_id"]].append(item)

    unsolvable, dataset = [], []
    for items in task_dict.values():
        if any(x["pass"] for x in items):
            dataset.extend(items)
        else:
            for item in items:
                item["confidence"] = 0
                item["prediction"] = None
            unsolvable.extend(items)
    return dataset, unsolvable


# Verification prompt sent as the single user message (built by preprocess_prompt).
# It is filled with str.format(question=..., code=...), so any literal "{" or "}"
# added to this text must be doubled. The trailing .strip() removes the newlines
# just inside the triple quotes. The "Judgment:" / "Confidence:" line format
# requested at the end is what postprocess_response parses.
# NOTE(review): the confidence description says "your solution is correct", but
# the model is asked to judge the code, not to solve the question; consider
# "your judgment is correct". Changing this text changes the prompt and therefore
# comparability with previously produced results.
prompt_template = """
Please verify if the following code correctly solves this question:

Question: {question}

Code:
{code}

Please provide:
1. Your judgment (True if the code is correct, False if it's not)
2. Your confidence (a float number between 0.00 and 1.00, where 0.00 means you have no idea about the solution and are purely guessing, and 1.00 means you are absolutely certain your solution is correct)

Format your response exactly like this:
Judgment: [True/False]
Confidence: [0.00-1.00]
""".strip()


def preprocess_prompt(item):
    """Wrap an item's question and candidate code into a one-turn chat request.

    The question text is taken from the content of the item's first message and
    the code from ``item["code"]``; both are substituted into ``prompt_template``.
    """
    content = prompt_template.format(
        question=item["messages"][0]["content"],
        code=item["code"],
    )
    user_turn = {"role": "user", "content": content}
    return [user_turn]


def _strip_markdown(line):
    """Drop asterisks, surrounding whitespace and leading '#', '-', '>' markers from a line."""
    return line.replace("*", "").strip().lstrip("#->").strip()


def _first_token(text):
    """Return the first word of *text* minus wrapping brackets/trailing punctuation ('' if none)."""
    words = text.split(maxsplit=1)
    if not words:
        return ""
    return words[0].lstrip("[(").rstrip(".,;:!)]")


def postprocess_response(response):
    """Parse a model reply into ``(judgment, confidence)``.

    Looks for lines of the form ``Judgment: True|False`` (also spelled
    "Judgement") and ``Confidence: <float>``. It tolerates markdown emphasis
    (``**Judgment:** True``), leading list/heading markers, surrounding
    brackets (``[0.9]``), trailing punctuation and extra words after the value.
    When a field appears several times, the last parseable occurrence wins.

    Args:
        response: Reply text, or ``None`` when the server returned no content.

    Returns:
        ``judgment``: ``True``/``False``, or ``None`` if not found.
        ``confidence``: a float clamped to [0, 1], or ``None`` if not found,
        unparseable or NaN.
    """
    if not isinstance(response, str):
        # e.g. a chat completion whose message carried no text content.
        return None, None

    # Keep only what follows the last "</think>" so reasoning text is ignored.
    # Checking for the closing tag also covers replies that contain only the
    # closing tag (e.g. when the opening tag is part of the prompt).
    if "</think>" in response:
        response = response.rsplit("</think>", 1)[1]

    judgment = None
    confidence = None
    for raw_line in response.splitlines():
        key, sep, value = _strip_markdown(raw_line).partition(":")
        if not sep:
            continue
        key = key.strip().lower()
        token = _first_token(value)

        if key in ("judgment", "judgement"):
            word = token.lower()
            if word in ("true", "false"):
                judgment = word == "true"
        elif key == "confidence":
            try:
                parsed = float(token)
            except ValueError:
                continue
            # NaN would otherwise survive min/max as 1.0, a spurious certainty.
            if not math.isnan(parsed):
                confidence = min(1.0, max(0.0, parsed))

    return judgment, confidence


if __name__ == "__main__":
    # Script entry point: ask the served model to verify every candidate of every
    # solvable task, parse each reply, and write all items to --output.
    parser = argparse.ArgumentParser()
    # All paths are required: a missing --output used to surface only after every
    # API request had completed, discarding the results.
    parser.add_argument("--model", type=str, required=True, help="path/to/model")
    parser.add_argument("--testset", type=str, required=True, help="path/to/testset")
    parser.add_argument("--output", type=str, required=True, help="path/to/score")
    parser.add_argument(
        "--max-workers",
        type=int,
        default=8,
        help="number of concurrent requests sent to the server",
    )
    args = parser.parse_args()

    chat_fun = partial(chat, model=args.model)

    dataset, unsolvable = load_dataset(args.testset)
    prompts = list(map(preprocess_prompt, dataset))

    from tqdm.contrib.concurrent import thread_map

    # thread_map behaves like list(map(...)) on a thread pool (with a progress
    # bar), so responses come back in prompt order and zip() pairs them correctly.
    responses = thread_map(chat_fun, prompts, max_workers=args.max_workers)

    for item, response in zip(dataset, responses):
        judgment, confidence = postprocess_response(response)
        item["prediction"] = judgment
        item["confidence"] = confidence
        item["response"] = response

    # Unsolvable items already carry prediction=None / confidence=0 from load_dataset.
    save_jsonl(dataset + unsolvable, args.output)
