# Adapted from codeparrot/apps_metric/utils.py
# https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py

import json
import multiprocessing
import numpy as np
from datasets import load_dataset
from tqdm.contrib.concurrent import process_map

from step1_apps_test import run_test
from utils import extract_code, read_config, load_jsonl, save_jsonl

# Timeout in seconds that test_generation passes to check_correctness; the
# test subprocess is allowed TIMEOUT + 1 seconds before it is killed.
TIMEOUT = 10


def _temp_run(sample, generation, debug, result):
    """Child-process entry point: run the tests and append the outcome to ``result``.

    Kept at module level (rather than nested inside ``check_correctness``) so it
    can be pickled when multiprocessing uses the "spawn" or "forkserver" start
    methods; a nested function would make ``Process.start()`` fail there.
    """
    result.append(run_test(sample, test=generation, debug=debug))


def check_correctness(sample, generation, timeout, debug=False):
    """Check correctness of code generation with a global timeout.

    The global timeout is to catch some extreme/rare cases not handled by the
    timeouts inside ``run_test``.

    Args:
        sample: APPS problem record; ``sample["input_output"]`` must be a JSON
            string with an ``"inputs"`` list (used to size the fallback result).
        generation: Source code of the candidate solution.
        timeout: Seconds to wait; the child process gets ``timeout + 1``
            seconds before it is killed.
        debug: Print diagnostics.

    Returns:
        Whatever ``run_test`` produced for this sample. If the child produced
        nothing (crashed or was killed), a list with one ``-1`` per test input.
    """
    # The manager's server process is shut down when the ``with`` block exits,
    # so the result must be copied out of the proxy before leaving it.
    with multiprocessing.Manager() as manager:
        result = manager.list()
        p = multiprocessing.Process(
            target=_temp_run, args=(sample, generation, debug, result)
        )
        p.start()
        p.join(timeout=timeout + 1)
        if p.is_alive():
            p.kill()
            # Reap the killed child so it does not linger as a zombie.
            p.join()
        has_result = len(result) > 0
        outcome = result[0] if has_result else None

    if not has_result:
        in_outs = json.loads(sample["input_output"])
        # consider that all tests failed
        outcome = [-1] * len(in_outs["inputs"])
        if debug:
            print("global timeout")
    return outcome


def test_generation(args, debug=False):
    """Evaluate a single generated solution and record whether it passed.

    Args:
        args: ``(apps_item, code_sample)`` pair, packed into one tuple so the
            function can be mapped directly with ``process_map``.
            ``apps_item`` is the APPS dataset row. ``code_sample`` is a dict
            whose last ``messages`` entry must have role ``"assistant"``.
        debug: Print progress and failure details.

    Returns:
        ``code_sample``, mutated in place with ``eval_result`` set to ``True``
        only if every per-test result is ``> 0`` (``True`` counts as 1).
    """
    apps_item, code_sample = args
    message = code_sample["messages"][-1]
    assert message["role"] == "assistant"
    code = extract_code(message["content"])

    # -2 is the sentinel for "the test harness itself failed". It is replaced
    # only once a complete, normalised result list exists, so a malformed
    # return value from check_correctness can never leak past this point.
    curr_res = [-2]
    try:
        raw_res = check_correctness(apps_item, code, timeout=TIMEOUT, debug=debug)
        if debug:
            print(f"\nSuccessful compilation of task {code}!")
        fixed = []
        for e in raw_res:
            # Unwrap numpy arrays / booleans into plain Python values.
            if isinstance(e, np.ndarray):
                e = e.item(0)
            if isinstance(e, np.bool_):
                e = bool(e)
            fixed.append(e)
        curr_res = fixed
        if debug and not np.all(curr_res):
            print(curr_res)
            print(f"Results were not True for all test cases")
    except Exception as e:
        if debug:
            print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")

    try:
        passed = bool(np.all(np.asarray(curr_res) > 0))
    except (TypeError, ValueError):
        # Unorderable (e.g. None) or ragged entries cannot count as a pass;
        # failing here must not abort the whole worker pool.
        passed = False
    code_sample["eval_result"] = passed

    return code_sample


def evaluate_code_samples(code_samples: list, dataset_path: str, chunksize=None, max_workers=None):
    """Evaluate every code sample against its APPS problem in parallel.

    Args:
        code_samples: Sample dicts, each with a ``problem_id`` of the form
            ``"<split>_<index>"`` identifying a row of the APPS dataset.
        dataset_path: Path or hub name passed to ``datasets.load_dataset``.
        chunksize: Tasks handed to a worker at a time. Defaults to roughly four
            chunks per worker, capped at 1000.
        max_workers: Worker process count. Defaults to ``cpu_count()``.

    Returns:
        The list returned by ``process_map``, in input order: each sample
        annotated by ``test_generation``.
    """
    apps_eval = load_dataset(dataset_path)

    def get_apps_item(item):
        # rsplit keeps split names that themselves contain "_" intact.
        split, idx = item["problem_id"].rsplit("_", 1)
        # get corresponding samples from APPS dataset
        return apps_eval[split][int(idx)]

    args = [(get_apps_item(sample), sample) for sample in code_samples]

    if max_workers is None:
        max_workers = multiprocessing.cpu_count()
    if chunksize is None:
        # Each task runs a subprocess with a multi-second timeout, so load
        # balancing matters more than IPC overhead. A fixed large chunksize
        # would hand every task to a single worker whenever there are fewer
        # samples than one chunk.
        chunksize = min(1000, max(1, len(args) // (max_workers * 4)))

    return process_map(test_generation, args, max_workers=max_workers, chunksize=chunksize)


def evaluate(code_sample_path, dataset_path, output_path):
    """Score the samples stored at ``code_sample_path`` against the APPS dataset
    at ``dataset_path`` and save the annotated samples to ``output_path``."""
    save_jsonl(
        evaluate_code_samples(load_jsonl(code_sample_path), dataset_path),
        output_path,
    )


if __name__ == "__main__":
    # Paths come from the project config: sampled generations in, evaluation results out.
    config = read_config()
    sample_result_path = config["sample"]["sample_result_path"]
    apps_dataset = config["apps"]
    evaluate_result_path = config["evaluate"]["evaluate_result_path"]
    evaluate(sample_result_path, apps_dataset, evaluate_result_path)
