from datasets import load_dataset
import json
from utils import read_config, save_jsonl
from transformers import AutoTokenizer


def mk_prompt(doc) -> list[dict[str, str]]:
    """Build a single-turn chat conversation asking the model to solve one problem.

    The prompt text is laid out as: a fixed instruction, a ``QUESTION:`` header,
    the problem statement, the optional starter code (appended directly after
    the statement, with no separator), a line naming the I/O format, and a
    reminder to close the markdown code block.

    Args:
        doc: A dataset row (mapping) with ``"question"``, ``"starter_code"`` and
            ``"input_output"`` entries. ``"input_output"`` is parsed as JSON;
            when it decodes to an object with a truthy ``"fn_name"`` the problem
            is labelled call-based, otherwise standard-input based (including
            when it is missing, empty, not valid JSON, or not a JSON object).

    Returns:
        A one-element list ``[{"role": "user", "content": prompt}]``.
    """
    prompt = "Write Python code to solve competitive programming problems in a markdown code block."

    # Empty (or missing) starter code means there is nothing to append.
    starter_code = doc["starter_code"] or None

    try:
        input_output = json.loads(doc["input_output"])
    except (TypeError, ValueError):
        # TypeError: field is None / not a string; ValueError: "" or bad JSON.
        input_output = None
    # Only a JSON object can carry "fn_name"; anything else -> standard input.
    fn_name = input_output.get("fn_name") if isinstance(input_output, dict) else None

    prompt += "\nQUESTION:\n"
    prompt += doc["question"]
    if starter_code:
        prompt += starter_code
    if not fn_name:
        prompt += "\nUse Standard Input format"
    else:
        prompt += "\nUse Call-Based format"

    prompt += "\nPlease generate the code in a ```python markdown block, ensuring to include the closing ``` at the end."

    return [{"role": "user", "content": prompt}]


def mk_sample_prompt(model_path, apps_path, output_path, max_length=4096, max_new_tokens=512):
    """Build chat prompts for every usable problem in both splits and save them.

    Iterates over the ``train`` and ``test`` splits loaded from *apps_path* and
    skips problems whose ``input_output`` field does not parse as JSON. Each
    remaining problem is rendered with ``mk_prompt``. Prompts are dropped when
    their chat-templated token count exceeds ``max_length - max_new_tokens``,
    leaving room for the completion. The surviving records are written with
    ``save_jsonl``.

    Args:
        model_path: Name or path passed to ``AutoTokenizer.from_pretrained``.
        apps_path: Name or path passed to ``load_dataset``.
        output_path: Destination passed to ``save_jsonl``.
        max_length: Total token budget for prompt plus generation (default 4096).
        max_new_tokens: Tokens reserved for the completion (default 512).

    Each saved record has the form
    ``{"problem_id": "<split>_<problem_id>", "messages": [...]}``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    max_prompt_tokens = max_length - max_new_tokens

    prompts = []
    for split in ("train", "test"):
        ds = load_dataset(apps_path, split=split)
        for sample in ds:
            problem_id = f"{split}_{sample['problem_id']}"

            # Filter problems without a parseable input_output spec.
            # TypeError covers a None field; ValueError covers "" / bad JSON.
            try:
                json.loads(sample["input_output"])
            except (TypeError, ValueError):
                print(f"Skipping {problem_id}: Invalid JSON in input_output")
                continue

            messages = mk_prompt(sample)

            # Filter long prompts so the completion still fits in the budget.
            # NOTE(review): if the chat template already emits a BOS token,
            # encode() may add another one -- count could be off by one; confirm per model.
            chat_text = tokenizer.apply_chat_template(messages, tokenize=False)
            n_tokens = len(tokenizer.encode(chat_text))
            if n_tokens > max_prompt_tokens:
                print(f"Skipping {problem_id}: Token length {n_tokens} exceeds limit")
                continue

            prompts.append({"problem_id": problem_id, "messages": messages})

    print(f"size of dataset: {len(prompts)}")
    save_jsonl(prompts, output_path)


if __name__ == "__main__":
    # Resolve every path from the project config, then build the prompt file.
    config = read_config()
    model_path = config["model"]
    apps_path = config["apps"]
    output_path = config["sample"]["sample_prompt_path"]
    mk_sample_prompt(model_path, apps_path, output_path)