import argparse
from pathlib import Path
import os
import json
from functools import partial
from collections import defaultdict

from datasets import load_dataset
from vllm import SamplingParams
from transformers import AutoTokenizer


from codecritic.dataset.apps import mk_prompt
from codecritic.dataset.code import extract_code
from codecritic.evaluation.apps_eval import evaluate
from codecritic.utils.inference import generate_worker
from codecritic.utils.parallel import model_map
from codecritic.utils.json import load_jsonl, save_jsonl


def transform_to_prompt(apps, tokenizer, max_prompt_tokens=2048):
    """Build one chat prompt per usable problem in the "train" and "test" splits.

    A problem is skipped (with a message printed to stdout) when:

    * its ``input_output`` field cannot be parsed as JSON -- this includes
      values that are not strings at all, such as ``None``; or
    * its chat-templated prompt tokenizes to more than ``max_prompt_tokens``
      tokens.

    Args:
        apps: Mapping with "train" and "test" entries, each an iterable of
            problem records providing "problem_id", "input_output" and
            "difficulty".
        tokenizer: Object exposing ``apply_chat_template(messages, tokenize=True)``
            whose result supports ``len()``.
        max_prompt_tokens: Largest prompt length (in tokens) that is kept.

    Returns:
        A list of dicts with keys "dataset" ("apps-" + difficulty),
        "task_id" ("<split>-<problem_id>") and "messages" (the prompt
        returned by ``mk_prompt``).
    """
    prompts = []
    for split in ["train", "test"]:
        for item in apps[split]:
            task_id = split + "-" + str(item["problem_id"])

            # Only the parseability of the test cases matters here; the parsed
            # value itself is discarded. json.loads raises ValueError
            # (JSONDecodeError) for malformed text and TypeError for
            # non-string input, and both mean the problem is unusable.
            try:
                json.loads(item["input_output"])
            except (ValueError, TypeError):
                print(f"Skipping {task_id}: Invalid JSON in input_output")
                continue

            prompt = mk_prompt(item)

            # Filter long prompts
            tokenized_question = tokenizer.apply_chat_template(prompt, tokenize=True)
            length = len(tokenized_question)
            if length > max_prompt_tokens:
                print(f"Skipping {task_id}: Token length {length} exceeds limit")
                continue

            prompts.append(
                {
                    "dataset": "apps-" + item["difficulty"],
                    "task_id": task_id,
                    "messages": prompt,
                }
            )

    return prompts


if __name__ == "__main__":
    # Pipeline:
    #   1. Sample solutions for APPS problems (or reuse cached "<path>.raw" files).
    #   2. Split the samples into a train set and a held-out test set.
    #   3. Evaluate both sets and write the results to --train / --test.
    parser = argparse.ArgumentParser()
    # --model is only used when samples have to be generated, so it stays
    # optional here and is validated below once we know generation is needed.
    parser.add_argument("--model", type=str, help="path/to/model")
    parser.add_argument("--apps", type=str, required=True, help="path/to/apps")
    parser.add_argument("--train", type=str, required=True, help="path/to/train")
    parser.add_argument("--test", type=str, required=True, help="path/to/test")
    parser.add_argument(
        "--tp", type=int, default=1, help="tensor parallel"
    )
    args = parser.parse_args()
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Unevaluated samples are cached next to the final outputs so that a rerun
    # can skip generation and go straight to evaluation.
    train_raw_path = Path(args.train + ".raw")
    test_raw_path = Path(args.test + ".raw")
    need_generation = not (train_raw_path.exists() and test_raw_path.exists())

    # Fail fast, before the dataset is loaded, if generation is required but
    # no model was given.
    if need_generation and args.model is None:
        parser.error("--model is required when the .raw sample files do not exist")

    apps = load_dataset(args.apps)

    if need_generation:
        tokenizer = AutoTokenizer.from_pretrained(args.model)
        dataset = transform_to_prompt(apps, tokenizer)

        # sampling
        sampling_params = SamplingParams(
            n=50,
            temperature=0.8,
            top_p=0.95,
            max_tokens=2048,
        )

        worker = partial(
            generate_worker, model_path=args.model, sampling_params=sampling_params
        )
        dataset = model_map(worker, dataset, args.tp)

        # postprocess: collect all generated samples of the same problem
        grouped = defaultdict(list)
        for sample in dataset:
            grouped[sample["task_id"]].append(sample)

        # Half-open [start, end) problem-id ranges of the APPS "test" split
        # that form the held-out test set.
        held_out_ranges = [(0, 300), (3000, 3100), (4000, 4100)]

        def is_in_test(task_id):
            """Return True if task_id ("<split>-<problem_id>") is held out for testing."""
            split, idx = task_id.split("-", 1)
            if split != "test":
                return False
            idx = int(idx)
            return any(start <= idx < end for start, end in held_out_ranges)

        # Everything that is not held out -- including the remaining problems
        # of the "test" split -- goes to the train set.
        trainset, testset = [], []
        for task_id, group in grouped.items():
            target = testset if is_in_test(task_id) else trainset
            for idx, sample in enumerate(group):
                # solution_id numbers the samples within one problem.
                sample["solution_id"] = idx
                # The code is extracted from the last message of the sample.
                sample["code"] = extract_code(sample["messages"][-1]["content"])
                target.append(sample)

        save_jsonl(trainset, train_raw_path)
        save_jsonl(testset, test_raw_path)
    else:
        trainset = load_jsonl(train_raw_path)
        testset = load_jsonl(test_raw_path)

    print("Start evaluation")
    trainset = evaluate(trainset, apps)
    testset = evaluate(testset, apps)

    save_jsonl(trainset, args.train)
    save_jsonl(testset, args.test)
