import argparse
import datetime
from itertools import chain
from functools import partial
import os
from tqdm.contrib.concurrent import thread_map
from threading import Lock
from openai import OpenAI
import json

from codecritic.data.cov import (
    convert_preference_to_vot_prompt,
    convert_sft_to_vot_prompt,
    convert_cov_to_cov_dataset,
)
from codecritic.utils.json import load_json, load_jsonl, save_jsonl


# Module-level OpenAI SDK client configured with the DeepSeek base URL; it is
# the client `worker` uses to request chat completions.
# NOTE: this is constructed at import time, and the `os.environ[...]` lookup
# raises KeyError if DEEPSEEK_API_KEY is not set in the environment.
client = OpenAI(
    base_url="https://api.deepseek.com/",
    api_key=os.environ["DEEPSEEK_API_KEY"]
)


def worker(index, cov_prompt, lock, tmp_path):
    """Request one completion for ``cov_prompt`` and log the result.

    The assistant reply is appended to ``cov_prompt["messages"]`` in place,
    the prompt is tagged with ``index``, and the whole record is appended as
    one JSON line to ``tmp_path``.

    Args:
        index: Position of this prompt in the input list; stored under the
            ``"index"`` key of the record.
        cov_prompt: Dict with a ``"messages"`` list that is sent to the chat
            completion endpoint. It is mutated by this call.
        lock: Lock held while appending to ``tmp_path`` so that lines written
            by concurrent workers do not interleave.
        tmp_path: Path of the JSONL log file that records are appended to.

    Returns:
        The same ``cov_prompt`` object, now including the assistant message
        and the ``"index"`` key.
    """
    response = client.chat.completions.create(
        messages=cov_prompt["messages"],
        model="deepseek-coder",
        max_tokens=2048,
        temperature=0,
    )
    reply = {
        "role": "assistant",
        "content": response.choices[0].message.content,
    }
    cov_prompt["messages"].append(reply)

    # Append under the lock: every worker thread writes to the same file.
    with lock, open(tmp_path, "a", encoding="utf-8") as log_file:
        cov_prompt["index"] = index
        log_file.write(f"{json.dumps(cov_prompt)}\n")

    return cov_prompt


if __name__ == "__main__":
    # Script entry point: turn a dataset into prompts, request a completion
    # for every prompt on a thread pool, then save both the raw conversations
    # (`<out>.raw`) and the converted dataset (`<out>`).
    parser = argparse.ArgumentParser()
    # --dataset, --dataset_type and --out are required: without them the
    # script cannot run, and a missing --out used to surface only as a
    # TypeError when saving, after every API request had already been made.
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument(
        "--dataset_type", type=str, choices=["reward", "sft"], required=True
    )
    parser.add_argument("--out", type=str, required=True)
    # NOTE(review): --mode is left optional because how the converters treat
    # mode=None is not visible from this file — confirm and make it required
    # if None is not a supported value.
    parser.add_argument("--mode", type=str, choices=["train", "test"])
    args = parser.parse_args()

    if args.dataset_type == "reward":
        preference_dataset = load_json(args.dataset)  # TODO change to jsonl
        # Each preference item yields several prompts; flatten them into one list.
        cov_prompts = list(
            chain.from_iterable(
                convert_preference_to_vot_prompt(x, splitter=False, mode=args.mode)
                for x in preference_dataset
            )
        )
    else:
        # argparse (`required` + `choices`) guarantees this is "sft".
        sft_dataset = load_jsonl(args.dataset)
        cov_prompts = [
            convert_sft_to_vot_prompt(x, splitter=False, mode=args.mode)
            for x in sft_dataset
        ]

    # Each finished request is also appended to a timestamped log file by
    # `worker`, so completed results remain on disk if the run is interrupted.
    lock = Lock()
    tmp_path = f"log_{datetime.datetime.now().strftime('%y%m%d_%H%M')}.jsonl"
    total_len = len(cov_prompts)
    w = partial(worker, lock=lock, tmp_path=tmp_path)

    # thread_map pairs each index with its prompt and returns results in
    # input order.
    covs = thread_map(w, range(total_len), cov_prompts, max_workers=8)

    save_jsonl(covs, args.out + ".raw")
    dataset = [convert_cov_to_cov_dataset(x, args.mode) for x in covs]
    save_jsonl(dataset, args.out)
