import argparse
from itertools import chain
import os
from tqdm.contrib.concurrent import thread_map
from threading import Lock
from openai import OpenAI
import json

from codecritic.data.cov import (
    convert_preference_to_vot_prompt,
    convert_sft_to_vot_prompt,
    convert_cov_to_cov_dataset,
)
from codecritic.data.utils import save_jsonl_dataset
from codecritic.utils.json import load_json, load_jsonl


# One DeepSeek client shared by every worker thread. It is built at import
# time, so a missing DEEPSEEK_API_KEY environment variable fails immediately
# with a KeyError rather than mid-run.
client = OpenAI(
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url="https://api.deepseek.com/",
)


def worker(args):
    """Query DeepSeek for one prompt, record the reply and checkpoint it.

    ``args`` is an ``(index, cov_prompt, lock)`` tuple, so the function can be
    driven directly by ``thread_map`` over a list of tuples.

    The assistant reply is appended to ``cov_prompt["messages"]`` in place.
    Then, while ``lock`` is held, the prompt is tagged with ``index`` and
    appended as one JSON line to ``tmp.jsonl``. The lock keeps concurrent
    workers from interleaving their writes.

    Returns:
        The same ``cov_prompt`` dict, now mutated.
    """
    idx, prompt, file_lock = args
    messages = prompt["messages"]
    response = client.chat.completions.create(
        model="deepseek-coder",
        messages=messages,
        temperature=0,
        max_tokens=2048,
    )
    reply = response.choices[0].message.content
    messages.append({"role": "assistant", "content": reply})
    # Acquire the lock first, then open the file: same order as a nested `with`.
    with file_lock, open("tmp.jsonl", "a", encoding="utf-8") as checkpoint:
        prompt["index"] = idx
        checkpoint.write(json.dumps(prompt) + "\n")
    return prompt


def parse_args(argv=None):
    """Parse and validate the command-line options for this script.

    Exactly one of ``--preference_dataset`` / ``--sft_dataset`` must be given.
    ``--output_dir`` is mandatory. These rules are enforced here, before any
    API request is made, so a bad invocation fails immediately. Otherwise it
    would fail only after the whole dataset had been sent to the API.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: (via argparse) when the arguments are invalid.
    """
    parser = argparse.ArgumentParser()
    # Mutually exclusive: previously, passing both silently ignored the SFT set.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--preference_dataset", type=str)
    source.add_argument("--sft_dataset", type=str)
    parser.add_argument("--output_dir", type=str, required=True)
    return parser.parse_args(argv)


def _load_cov_prompts(args):
    """Build the flat list of CoV prompts from whichever dataset was supplied."""
    if args.preference_dataset is not None:
        preference_dataset = load_json(args.preference_dataset)
        # Each preference item converts to an iterable of prompts; flatten them.
        return list(chain.from_iterable(
            convert_preference_to_vot_prompt(x, splitter=False)
            for x in preference_dataset
        ))
    sft_dataset = load_jsonl(args.sft_dataset)
    return [convert_sft_to_vot_prompt(x, splitter=False) for x in sft_dataset]


def main():
    """Generate a CoV completion for every prompt and save the dataset."""
    args = parse_args()
    cov_prompts = _load_cov_prompts(args)

    # One lock shared by all workers; it serializes their appends to tmp.jsonl.
    lock = Lock()
    thread_args = [(index, prompt, lock) for index, prompt in enumerate(cov_prompts)]

    covs = thread_map(worker, thread_args, max_workers=8)

    dataset = [convert_cov_to_cov_dataset(cov) for cov in covs]
    save_jsonl_dataset(dataset, args.output_dir)


if __name__ == "__main__":
    main()
