import argparse
from collections import defaultdict
from itertools import product, chain
from pathlib import Path
import os
from tqdm.contrib.concurrent import process_map
from rapidfuzz import fuzz

from codecritic.dataset.code import preprocess_code
from codecritic.utils.json import load_jsonl, save_jsonl


def group_and_filter(dataset):
    """Yield per-task sample groups whose ``pass`` values are mixed.

    Samples are bucketed by ``task_id`` (buckets come out in first-seen
    order). A bucket is yielded only when its samples carry exactly two
    distinct ``pass`` values. Assuming ``pass`` is a boolean flag, that
    means the task has at least one passing and one failing sample.
    Tasks where every sample passed, or every sample failed, are dropped.

    Args:
        dataset: iterable of sample dicts with at least ``task_id`` and
            ``pass`` keys.

    Yields:
        list of the sample dicts sharing one ``task_id``.
    """
    by_task = defaultdict(list)
    for record in dataset:
        by_task[record["task_id"]].append(record)

    yield from (
        samples
        for samples in by_task.values()
        if len({s["pass"] for s in samples}) == 2
    )


def compute_pair_similarity(group):
    """Score every (passing, failing) solution pair of one task by code similarity.

    Each sample's code is normalized with ``preprocess_code``. Within each
    outcome (passing / failing) a given normalized code is kept only once;
    the first occurrence wins. Failing samples whose normalized code is the
    empty string are skipped. Every remaining passing solution is then
    paired with every remaining failing solution, and the pair is scored
    with ``fuzz.ratio`` on the normalized code.

    Args:
        group: list of sample dicts for a single task. Each dict has
            ``solution_id``, ``code`` and ``pass`` keys. ``dataset`` and
            ``task_id`` are read from the first sample only.

    Returns:
        list of dicts with keys ``dataset``, ``task_id``, ``chosen`` (passing
        solution id), ``rejected`` (failing solution id) and ``similarity``.
        The list is empty when ``group`` is empty or has no passing or no
        failing sample.
    """
    if not group:
        return []
    dataset_name = group[0]["dataset"]
    task_id = group[0]["task_id"]

    # Normalized code already kept, per outcome. The failing set is seeded
    # with "" so empty code is never used as a rejected sample (presumably
    # blank / failed generations -- confirm).
    correct_code_set, incorrect_code_set = set(), {""}
    correct_samples, incorrect_samples = [], []

    for sample in group:
        code = preprocess_code(sample["code"])
        if sample["pass"]:
            seen, bucket = correct_code_set, correct_samples
        else:
            seen, bucket = incorrect_code_set, incorrect_samples
        if code in seen:
            continue
        # Remember this code so later duplicates with the same outcome are
        # skipped instead of producing redundant pairs.
        seen.add(code)
        bucket.append({"solution_id": sample["solution_id"], "code": code})

    return [
        {
            "dataset": dataset_name,
            "task_id": task_id,
            "chosen": correct["solution_id"],
            "rejected": incorrect["solution_id"],
            "similarity": fuzz.ratio(correct["code"], incorrect["code"]),
        }
        for correct, incorrect in product(correct_samples, incorrect_samples)
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, help="path/to/dataset")
    parser.add_argument("--output", type=str, help="path/to/output")
    args = parser.parse_args()

    dataset = load_jsonl(args.dataset)

    cache_path = Path(args.dataset + ".pairinfo")
    if not cache_path.exists():
        results = process_map(
            compute_pair_similarity,
            group_and_filter(dataset),
            max_workers=os.cpu_count(),
            chunksize=1,
        )
        pairinfo = list(chain.from_iterable(results))
        save_jsonl(pairinfo, cache_path)
    else:
        pairinfo = load_jsonl(cache_path)
        print(f"load cached similarity information from {cache_path}")


    # select pairs
    task_groups = defaultdict(list)
    for item in pairinfo:
        task_groups[item["task_id"]].append(item)

    selected_pairs = []
    for task, items in task_groups.items():
        sorted_items = sorted(items, key=lambda x: x["similarity"], reverse=True)[:2]
        selected_pairs.extend(sorted_items)

    save_jsonl(selected_pairs, args.output)
