import sys
import numpy as np
from sklearn import metrics
from collections import defaultdict


def estimate_pass_at_k(
    num_samples: "int | list[int]", num_correct: list[int], k: int
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.

    Uses the unbiased estimator 1 - comb(n - c, k) / comb(n, k), where n is
    the number of samples generated for a problem and c how many passed.
    When k exceeds n the estimate reduces to "did any sample pass" (1.0 if
    c > 0, else 0.0).

    Args:
        num_samples: n for each problem, or a single int shared by all problems.
        num_correct: c for each problem, aligned with ``num_samples``.
        k: number of samples considered per problem.

    Returns:
        1-D float array with one pass@k estimate per problem.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if c == 0:
            # No correct sample means pass@k is 0 for every k. This must be
            # checked first: otherwise the branch below returns 1.0 whenever
            # k > n, even though nothing passed.
            return 0.0
        if n - c < k:
            # Fewer than k incorrect samples: every k-subset contains a pass.
            return 1.0
        # Product form of comb(n - c, k) / comb(n, k); avoids computing the
        # (potentially huge) binomial coefficients directly.
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, (int, np.integer)):
        # Broadcast a shared sample count across all problems.
        num_samples = [num_samples] * len(num_correct)

    return np.array(
        [estimator(int(n), int(c), k) for n, c in zip(num_samples, num_correct)]
    )


def pass_at_k(samples, ks: list[int]):
    """Average the unbiased pass@k estimate over tasks, once per k in ``ks``.

    Samples are bucketed by their ``"task_id"``. For each task, n is the size
    of its bucket and c is the sum of the samples' ``"pass"`` values. All
    tasks must have the same n.

    Returns a list of ``{"k", "pass@k", "score_func": "random"}`` dicts, one
    per entry of ``ks``, in the same order.
    """
    by_task = defaultdict(list)
    for entry in samples:
        by_task[entry["task_id"]].append(entry)

    sizes = [len(bucket) for bucket in by_task.values()]
    passes = [sum(entry["pass"] for entry in bucket) for bucket in by_task.values()]
    assert len(set(sizes)) == 1, "Groups don't have same size"

    return [
        {
            "k": k,
            "pass@k": np.mean(estimate_pass_at_k(sizes, passes, k)),
            "score_func": "random",
        }
        for k in ks
    ]


def positive_only(item):
    """Rank a sample by its ``positive_score`` alone; negative evidence is ignored."""
    score = item["positive_score"]
    return score


def postive_and_negative(item):
    """Score a sample by its share of positive evidence: pos / (pos + neg).

    Returns 0 when the two scores sum to zero, where the ratio is undefined.
    This is the same score ``pos_neg_filter_uncertain`` gives samples without
    enough evidence. Without the guard, one such item would raise
    ZeroDivisionError and abort the whole ranking or AUROC computation.

    The (misspelled) function name is kept unchanged because callers, and the
    ``score_func`` label produced by ``top_at_k``, refer to it.
    """
    pos = item["positive_score"]
    neg = item["negative_score"]
    total = pos + neg
    if total == 0:
        return 0
    return pos / total


def pos_neg_filter_uncertain(item, threshold):
    """Positive-evidence ratio, but 0 for samples with too little evidence.

    Args:
        item: sample dict with ``"positive_score"`` and ``"negative_score"``.
        threshold: minimum value of ``positive_score + negative_score`` for
            the ratio to be used.

    Returns:
        0 if the combined score is below ``threshold`` or exactly zero (where
        the ratio is undefined); otherwise pos / (pos + neg).
    """
    pos = item["positive_score"]
    neg = item["negative_score"]
    total = pos + neg
    # The explicit zero check matters when threshold <= 0: the comparison
    # alone would let a 0 / 0 division through.
    if total < threshold or total == 0:
        return 0
    return pos / total


def _score_func_name(score_func) -> str:
    """Return a readable label for ``score_func``, unwrapping functools.partial."""
    name = getattr(score_func, "__name__", None)
    if name is not None:
        return name
    # functools.partial objects have no __name__ but keep the target in .func.
    wrapped = getattr(score_func, "func", None)
    if wrapped is not None:
        return _score_func_name(wrapped)
    return repr(score_func)


def top_at_k(samples, ks: list[int], score_func):
    """Fraction of tasks with a passing sample among the k highest-scored ones.

    Samples are grouped by ``"task_id"`` and each group is ranked by
    ``score_func(sample)`` in descending order. A task counts for a given k
    when its first passing sample sits at rank < k. A task with no passing
    sample never counts. All tasks must have the same number of samples.

    Args:
        samples: iterable of dicts with at least ``"task_id"`` and ``"pass"``.
        ks: cut-offs to evaluate.
        score_func: callable mapping a sample to a sortable score. Plain
            functions, lambdas and ``functools.partial`` objects are all
            accepted (e.g. ``partial(pos_neg_filter_uncertain, threshold=t)``).

    Returns:
        List of ``{"k", "pass@k", "score_func"}`` dicts, one per k, where
        ``"score_func"`` is the function's name (the wrapped function's name
        for partials).
    """
    grouped = defaultdict(list)
    for sample in samples:
        grouped[sample["task_id"]].append(sample)

    num_samples, first_pass_indices = [], []
    for group in grouped.values():
        num_samples.append(len(group))
        # sorted() is stable even with reverse=True, so tied scores keep their
        # input order; score_func is evaluated once per sample.
        ranked = sorted(group, key=score_func, reverse=True)
        first_pass_idx = next(
            (idx for idx, item in enumerate(ranked) if item["pass"]), sys.maxsize
        )
        first_pass_indices.append(first_pass_idx)

    assert len(set(num_samples)) == 1, "Groups don't have same size"

    label = _score_func_name(score_func)
    num_tasks = len(first_pass_indices)
    results = []
    for k in ks:
        top_k = sum(1 for x in first_pass_indices if x < k) / num_tasks
        results.append({"k": k, "pass@k": top_k, "score_func": label})
    return results


def auroc(samples, score_func):
    """Area under the ROC curve of ``score_func`` as a predictor of ``"pass"``.

    Each sample's label is 1 when its ``"pass"`` value is truthy, else 0. The
    prediction is ``score_func(sample)``. The ``samples`` collection is
    traversed twice, so it should be a re-iterable sequence.

    Returns:
        Tuple ``(auc, fpr, tpr)`` from sklearn's ``roc_curve`` / ``auc``.
    """
    labels = np.array([1 if entry["pass"] else 0 for entry in samples])
    scores = np.array([score_func(entry) for entry in samples])

    false_pos_rate, true_pos_rate, _ = metrics.roc_curve(labels, scores)
    area = metrics.auc(false_pos_rate, true_pos_rate)
    return area, false_pos_rate, true_pos_rate
