import argparse
from pathlib import Path

from codecritic.utils.json import save_json, save_jsonl
from codecritic.utils.vllm import vllm_chatcomplete
from codecritic.sampling.sample_apps import mk_sample_prompt
from codecritic.sampling.evaluate_code import evaluate
from codecritic.sampling.sort_split_dataset import sort_and_split_dataset


if __name__ == "__main__":
    # Script entry point. Stages, in order: build prompts, sample
    # completions with vLLM, evaluate the samples, then sort/split the
    # evaluated dataset. Every intermediate artifact is written under
    # --output_dir.
    parser = argparse.ArgumentParser()
    # All three options are mandatory. Marking them required makes argparse
    # print a usage error instead of letting None reach Path() or the
    # project helpers below.
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="model argument forwarded to mk_sample_prompt and vllm_chatcomplete",
    )
    parser.add_argument(
        "--apps",
        type=str,
        required=True,
        help="apps argument forwarded to mk_sample_prompt and evaluate",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="directory that receives all output files (created if missing)",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    # parents=True so a nested, not-yet-existing output path works too.
    output_dir.mkdir(parents=True, exist_ok=True)

    prompts = mk_sample_prompt(args.model, args.apps)

    # Persist the sampling settings next to the outputs they produced.
    sampling_params = dict(n=50, temperature=0.6, max_new_tokens=2048)
    save_json(sampling_params, output_dir / "sampling_params.json")
    codes = vllm_chatcomplete(args.model, prompts, sampling_params)
    save_jsonl(codes, output_dir / "sample.jsonl")

    dataset = evaluate(codes, args.apps)
    save_jsonl(dataset, output_dir / "dataset.jsonl")

    # The splitter receives the same n that was used for sampling.
    train, test, min_test = sort_and_split_dataset(dataset, sampling_params["n"])
    save_jsonl(train, output_dir / "train.jsonl")
    save_jsonl(test, output_dir / "test.jsonl")
    save_jsonl(min_test, output_dir / "min_test.jsonl")
