from utils import load_jsonl, save_json, extract_code, read_config
from utils_dataset import mk_preference_dataset_info, mk_preference_pair, save_dataset
from nltk.metrics.distance import edit_distance
from collections import defaultdict
from itertools import product, chain
import multiprocessing
from tqdm.contrib.concurrent import process_map


def mk_problem_groups(train_dataset_path, n):
    """Group sampled solutions by problem, keeping only mixed-result problems.

    The JSONL file at ``train_dataset_path`` must hold exactly ``n``
    consecutive records per problem (i.e. be sorted by ``problem_id``). Each
    record is expected to carry ``problem_id``, ``eval_result`` and
    ``messages``. The instruction is read from the first message of the
    group's first record, and the code is extracted from each record's second
    message. Groups whose samples all passed or all failed are dropped,
    because they cannot produce a (correct, incorrect) pair.

    Args:
        train_dataset_path: path to the sampled training data (JSONL).
        n: number of samples generated per problem.

    Returns:
        list of dicts with keys ``problem_id``, ``instruction``,
        ``correct_codes`` and ``incorrect_codes``.
    """
    train_dataset = load_jsonl(train_dataset_path)

    assert len(train_dataset) % n == 0
    problems = []
    for i in range(len(train_dataset) // n):
        problem = train_dataset[i * n : (i + 1) * n]
        problem_id = problem[0]["problem_id"]
        # Check sortedness for every group, not just the mixed-result ones;
        # otherwise a misaligned all-pass/all-fail group goes unnoticed.
        assert all(
            d["problem_id"] == problem_id for d in problem
        ), "dataset is not sorted"

        eval_results = [d["eval_result"] for d in problem]
        # Only problems with at least one passing and one failing sample can
        # yield a (correct, incorrect) preference pair.
        if True not in eval_results or False not in eval_results:
            continue

        correct_codes, incorrect_codes = [], []
        for d in problem:
            code = extract_code(d["messages"][1]["content"])
            if d["eval_result"]:
                correct_codes.append(code)
            else:
                incorrect_codes.append(code)
        problems.append(
            dict(
                problem_id=problem_id,
                instruction=problem[0]["messages"][0]["content"],
                correct_codes=correct_codes,
                incorrect_codes=incorrect_codes,
            )
        )
    return problems


def calculate_edit_distances_for_problem(problem):
    """Score every (correct, incorrect) code pairing of a single problem.

    Returns a list of ``(distance, problem_id, instruction, (correct, incorrect))``
    tuples. ``distance`` is the edit distance needed to transform the
    incorrect code into the correct code. Correct codes form the outer loop
    and incorrect codes the inner loop.
    """
    good_codes = problem["correct_codes"]
    bad_codes = problem["incorrect_codes"]
    return [
        # edit_distance(source, target): incorrect -> correct
        (edit_distance(bad, good), problem["problem_id"], problem["instruction"], (good, bad))
        for good in good_codes
        for bad in bad_codes
    ]


def calculate_edit_distances(problems):
    """Compute edit-distance pairs for all problems in parallel.

    Maps ``calculate_edit_distances_for_problem`` over ``problems`` with
    tqdm's ``process_map`` (one worker per CPU, chunks of 32 problems), then
    concatenates the per-problem lists into one flat list, preserving order.
    """
    per_problem = process_map(
        calculate_edit_distances_for_problem,
        problems,
        max_workers=multiprocessing.cpu_count(),
        chunksize=32,
    )
    flattened = []
    for problem_pairs in per_problem:
        flattened.extend(problem_pairs)
    return flattened


def mk_edit_distance_dataset(all_pairs, k, n, is_max=True):
    """Select top-k preference pairs ranked by edit distance.

    Pairs are visited in order of decreasing (``is_max=True``) or increasing
    (``is_max=False``) edit distance and greedily accepted subject to:

    * at most ``k`` pairs are returned in total;
    * each problem contributes at most ``n`` pairs;
    * within a problem, each code snippet (correct or incorrect) appears in
      at most one accepted pair.

    Args:
        all_pairs: iterable of ``(distance, problem_id, instruction,
            (correct_code, incorrect_code))`` tuples. It is not modified.
        k: maximum number of pairs to return.
        n: maximum number of pairs a single problem may contribute.
        is_max: prefer the largest distances when True, the smallest when False.

    Returns:
        ``(preference_pairs, pairs_metadata)``: parallel lists. The first holds
        ``mk_preference_pair(instruction, correct, incorrect)`` results. The
        second holds ``{"problem_id": ..., "edit_distance": ...}`` dicts.
    """
    # Use sorted() rather than list.sort() so the caller's list is left
    # untouched; callers may reuse it (e.g. for both max and min datasets).
    # The sort is stable, so ties keep their input order.
    ranked = sorted(all_pairs, key=lambda x: x[0], reverse=is_max)

    code_usages = defaultdict(set)
    problem_contributions = defaultdict(int)
    preference_pairs, pairs_metadata = [], []
    for distance, problem_id, instr, (correct, incorrect) in ranked:
        if len(preference_pairs) >= k:
            break

        used = code_usages[problem_id]
        if correct in used or incorrect in used:
            continue
        if problem_contributions[problem_id] >= n:
            continue

        used.update((correct, incorrect))
        problem_contributions[problem_id] += 1
        preference_pairs.append(mk_preference_pair(instr, correct, incorrect))
        pairs_metadata.append(dict(problem_id=problem_id, edit_distance=distance))

    return preference_pairs, pairs_metadata


if __name__ == "__main__":
    # Upper bound on the number of preference pairs in each output dataset.
    MAX_PAIRS = 10 * 1000
    # Upper bound on how many pairs a single problem may contribute.
    MAX_PAIRS_PER_PROBLEM = 5

    cfg = read_config()
    problems = mk_problem_groups(
        cfg["dataset"]["train_path"], cfg["sample"]["sampling_params"]["n"]
    )

    all_edit_distance_pairs = calculate_edit_distances(problems)

    # Build two datasets from the same candidate pairs: first the pairs with
    # the maximum edit distance, then those with the minimum edit distance.
    for dataset_key, is_max in (
        ("max_edit_distance", True),
        ("min_edit_distance", False),
    ):
        preference_pairs, metadata = mk_edit_distance_dataset(
            all_edit_distance_pairs,
            MAX_PAIRS,
            MAX_PAIRS_PER_PROBLEM,
            is_max=is_max,
        )
        dataset_cfg = cfg["preference_dataset"][dataset_key]
        dataset_info = mk_preference_dataset_info(dataset_cfg["dataset_name"])
        save_json(metadata, dataset_cfg["metadata_path"])
        save_dataset(cfg["llamafactory_path"], dataset_info, preference_pairs)
