# Adapted from codeparrot/apps_metric/utils.py
# https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py

import json
import multiprocessing
import numpy as np
from tqdm.contrib.concurrent import process_map

from codecritic.evaluation.apps_exec import run_test

# Timeout in seconds that `test_generation` passes to `check_correctness`;
# the watchdog there waits TIMEOUT + 1 seconds before killing the test process.
TIMEOUT: int = 10


def _temp_run(sample, generation, debug, result):
    """Child-process target: run the tests for `generation` and append the outcome to `result`.

    Defined at module level (not nested inside `check_correctness`) so it can be
    pickled when multiprocessing uses the "spawn" or "forkserver" start method.
    """
    result.append(run_test(sample, test=generation, debug=debug))


def check_correctness(sample, generation, timeout, debug=False):
    """Check correctness of code generation with a global timeout.

    The global timeout is to catch some extreme/rare cases not handled by the
    timeouts inside `run_test`: the tests run in a separate process, which is
    killed if it is still alive after ``timeout + 1`` seconds.

    Args:
        sample: Problem record forwarded to `run_test`. Its ``"input_output"``
            field is a JSON string whose ``"inputs"`` list sizes the fallback
            result when the child produces nothing.
        generation: Source code under test.
        timeout: Watchdog budget in seconds (one extra second is added).
        debug: Forwarded to `run_test`; also prints a note on a global timeout.

    Returns:
        Whatever `run_test` returned (presumably a per-test result list), or
        ``[-1] * len(inputs)`` if the child stored no result (global timeout
        or crash), i.e. every test is considered failed.
    """
    # The context manager guarantees the manager server process is shut down.
    with multiprocessing.Manager() as manager:
        shared = manager.list()
        proc = multiprocessing.Process(
            target=_temp_run, args=(sample, generation, debug, shared)
        )
        proc.start()
        proc.join(timeout=timeout + 1)
        if proc.is_alive():
            proc.kill()
            # Reap the killed child so it does not linger as a zombie process.
            proc.join()
        # Copy out of the proxy while the manager is still running.
        result = shared[:]

    if not result:
        in_outs = json.loads(sample["input_output"])
        # consider that all tests failed
        result = [[-1] * len(in_outs["inputs"])]
        if debug:
            print("global timeout")
    return result[0]


def test_generation(args, debug=False):
    """Run one code sample against its APPS problem and summarize the outcome.

    Args:
        args: ``(apps_item, sample)`` tuple. ``apps_item`` is forwarded to
            `check_correctness`; ``sample`` must provide ``"code"``,
            ``"task_id"`` and ``"solution_id"``.
        debug: Print diagnostic information.

    Returns:
        dict with ``task_id``, ``solution_id`` and boolean flags:
        ``pass`` -- at least one test result exists and all results are > 0;
        ``timeout`` -- some result is -1;
        ``compilerr`` -- some result is -2 (also set when the test framework
        raised, since the result then stays ``[-2]``).
    """
    apps_item, sample = args
    code = sample["code"]

    # -2 marks a compile error / framework failure; kept if anything below raises.
    curr_res = [-2]
    try:
        raw_res = check_correctness(apps_item, code, timeout=TIMEOUT, debug=debug)
        if debug:
            print(f"\nSuccessful compilation of task {code}!")
        # Normalize numpy scalars/arrays to builtin values. Commit only after the
        # whole conversion succeeds so curr_res is always a well-formed list.
        fixed = []
        for item in raw_res:
            if isinstance(item, np.ndarray):
                item = item.item(0)
            if isinstance(item, np.bool_):
                item = bool(item)
            fixed.append(item)
        curr_res = fixed
        if not np.all(curr_res):
            if debug:
                print(curr_res)
                print("Results were not True for all test cases")
    except Exception as e:
        if debug:
            print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")

    problem_result = np.asarray(curr_res)

    return {
        "task_id": sample["task_id"],
        "solution_id": sample["solution_id"],
        # np.all([]) is True, so an empty result (no test executed) must not pass.
        "pass": bool(curr_res) and bool(np.all(problem_result > 0)),
        "timeout": bool(-1 in curr_res),
        "compilerr": bool(-2 in curr_res),
    }


def evaluate_code_samples(code_samples, apps, chunksize=None, max_workers=None):
    """Evaluate every code sample in parallel worker processes.

    Args:
        code_samples: Sequence of sample dicts whose ``"task_id"`` has the form
            ``"<split>-<index>"``, plus the fields `test_generation` reads.
        apps: Mapping from split name to an indexable collection of problems.
        chunksize: Number of samples sent to a worker per task. Defaults to
            about five chunks per worker, so the work is spread over all
            processes. A large fixed chunk would funnel a small run into a
            single worker.
        max_workers: Number of worker processes; defaults to the CPU count.

    Returns:
        List of `test_generation` result dicts, in the order of ``code_samples``.
    """
    args = []
    for sample in code_samples:
        # rsplit: only the last '-' separates the split name from the index.
        split, idx = sample["task_id"].rsplit("-", 1)
        args.append((apps[split][int(idx)], sample))

    if max_workers is None:
        max_workers = multiprocessing.cpu_count()
    if chunksize is None:
        chunksize = max(len(args) // (max_workers * 5), 1)

    return process_map(
        test_generation, args, max_workers=max_workers, chunksize=chunksize
    )


def evaluate(code_samples, apps, loop_num=2):
    """
    Evaluate code samples, running the whole test suite ``loop_num`` times.

    There are some strange bugs in apps evaluation that cannot be reproduced.
    The observable issue is that the same code will yield different 'eval_result' values.
    Typically, the test framework may encounter an exception or decide that the code has timed out unreasonably.

    This function is an ugly workaround: every sample is evaluated ``loop_num``
    times and the runs are merged as follows:

    * if every run reported a compilation error, the sample fails;
    * otherwise, runs with a compilation error are discarded as framework
      glitches, and the sample passes only if all remaining runs passed.

    Args:
        code_samples: List of sample dicts (see `evaluate_code_samples`). It is
            iterated several times, so it must not be a one-shot iterator.
        apps: Problems indexed by split name and then by integer index.
        loop_num: Number of evaluation rounds (default 2).

    Returns:
        ``code_samples`` itself, with a boolean ``"pass"`` key set on each sample.

    Raises:
        ValueError: If ``loop_num`` is smaller than 1.
    """
    if loop_num < 1:
        raise ValueError(f"loop_num must be at least 1, got {loop_num!r}")

    all_results = [
        evaluate_code_samples(code_samples, apps) for _ in range(loop_num)
    ]

    for sample, runs in zip(code_samples, zip(*all_results)):
        assert len(set(x["task_id"] for x in runs)) == 1, "Mismatched task_id"
        assert len(set(x["solution_id"] for x in runs)) == 1, "Mismatched solution_id"
        assert sample["task_id"] == runs[0]["task_id"], "Mismatched task_id"
        assert sample["solution_id"] == runs[0]["solution_id"], "Mismatched solution_id"

        # Compilation errors in only some of the runs are treated as exceptions
        # and removed; if every run errored, the sample fails.
        clean_runs = [x for x in runs if not x["compilerr"]]
        sample["pass"] = bool(clean_runs) and all(x["pass"] for x in clean_runs)

    return code_samples
