from tqdm import tqdm
from utils import load_jsonl, save_jsonl, read_config
from utils_metric import group_results, score_pass_at_k
from step4_test_reward_model import preprocess_dataset, test_reward_model


if __name__ == "__main__":
    # Script entry point: score the minimal test set with the configured
    # reward model, persist the raw scores, then evaluate them with
    # score_pass_at_k for k = 1..15 and persist/print those results.
    cfg = read_config(["orm_testmodel"])
    orm_test_model = cfg["orm_testmodel"]

    raw_samples = load_jsonl(cfg["dataset"]["minimal_test_path"])

    # All per-model settings (model path, output paths) live under this key.
    orm_cfg = cfg["orm"][orm_test_model]
    reward_model_path = orm_cfg["model_path"]
    # NOTE(review): the meaning of the trailing positional `1` is defined by
    # preprocess_dataset and is not visible from this script -- confirm there.
    prepared_samples = preprocess_dataset(reward_model_path, raw_samples, 1)

    # Each prepared item is unpacked as the positional arguments of
    # test_reward_model; tqdm only adds a progress bar around the iteration.
    results = []
    for sample_args in tqdm(prepared_samples):
        results.append(test_reward_model(*sample_args))
    save_jsonl(results, orm_cfg["minimal_test_score_path"])
    # To skip scoring on a re-run, `results` could presumably be loaded back
    # with load_jsonl from the score file written above -- TODO confirm.

    groups = group_results(results, cfg["apps"])

    eval_results = []
    for k in range(1, 16):
        eval_results.append(score_pass_at_k(groups, k, orm_test_model))
    save_jsonl(eval_results, orm_cfg["eval_result_path"])
    print(eval_results)
