import argparse
from itertools import product, chain
from codecritic.utils.json import load_jsonl, save_jsonl


def mk_preference_pair(ds, pair):
    """Build one preference-pair (reward-model) record from a selected pair.

    Args:
        ds: Indexable as ``ds[task_id][solution_key]``; each solution record
            must carry a ``"messages"`` list.
        pair: Selection record with ``"task_id"``, ``"chosen"`` and
            ``"rejected"`` keys; the last two are keys into ``ds[task_id]``.

    Returns:
        A dict whose ``messages`` is the chosen solution's first message,
        ``chosen`` / ``rejected`` are the remaining messages of each solution,
        and ``meta_pairinfo`` is ``pair`` itself.
    """
    solutions = ds[pair["task_id"]]
    chosen_record = solutions[pair["chosen"]]
    rejected_record = solutions[pair["rejected"]]

    chosen_messages = chosen_record["messages"]
    # The shared prompt is taken from the chosen solution; both responses
    # drop their first message.
    return dict(
        messages=chosen_messages[:1],
        chosen=chosen_messages[1:],
        rejected=rejected_record["messages"][1:],
        meta_pairinfo=pair,
    )


def mk_sft(ds, pair):
    """Expand one selected pair into two SFT records: chosen first, then rejected.

    Args:
        ds: Indexable as ``ds[task_id][solution_key]``; each solution record
            must carry a ``"messages"`` list.
        pair: Selection record with ``"dataset"``, ``"task_id"``, ``"chosen"``
            and ``"rejected"`` keys; the last two are keys into ``ds[task_id]``.

    Returns:
        A two-element list of dicts with keys ``question`` (the solution's
        first message), ``response`` (its remaining messages), ``dataset``,
        ``task_id`` and ``solution_id`` (the solution key taken from ``pair``).
    """
    dataset_name = pair["dataset"]
    task_id = pair["task_id"]
    solutions = ds[task_id]

    def _record(solution_key):
        # Build one SFT example from the solution stored under ``solution_key``.
        messages = solutions[solution_key]["messages"]
        return {
            "question": messages[:1],
            "response": messages[1:],
            "dataset": dataset_name,
            "task_id": task_id,
            # Store the identifying key, not the whole solution record.
            "solution_id": solution_key,
        }

    # TODO add judgement response
    return [_record(pair["chosen"]), _record(pair["rejected"])]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert selected solution pairs into SFT or reward-model training data."
    )
    parser.add_argument("--dataset", type=str, required=True, help="path/to/dataset")
    parser.add_argument("--pairs", type=str, required=True, help="path/to/selected_pairs")
    parser.add_argument("--output", type=str, required=True, help="path/to/output")
    # Selects which builder is used below; the dispatch reads ``args.format``,
    # so this option must be declared.
    parser.add_argument(
        "--format",
        type=str,
        required=True,
        choices=["sft", "reward"],
        help="'sft': two records per pair via mk_sft; 'reward': one preference pair per pair",
    )
    args = parser.parse_args()

    # NOTE(review): mk_sft / mk_preference_pair index ``dataset[task_id][solution_key]``;
    # confirm that what load_jsonl returns supports that (a flat list of records would not).
    dataset = load_jsonl(args.dataset)
    selected_pairs = load_jsonl(args.pairs)

    if args.format == "sft":
        # Each pair yields two SFT records; flatten them into one list.
        sft_ds = list(chain.from_iterable(mk_sft(dataset, pair) for pair in selected_pairs))
        save_jsonl(sft_ds, args.output)
    elif args.format == "reward":
        reward_ds = [mk_preference_pair(dataset, pair) for pair in selected_pairs]
        save_jsonl(reward_ds, args.output)
    else:
        # Defensive guard in case a choice is added without a matching branch.
        raise NotImplementedError(f"Unknown format: {args.format}")
